diff --git a/packages/assets-controllers/src/TokenDetectionController.test.ts b/packages/assets-controllers/src/TokenDetectionController.test.ts index 7413ad9ff2..305594db82 100644 --- a/packages/assets-controllers/src/TokenDetectionController.test.ts +++ b/packages/assets-controllers/src/TokenDetectionController.test.ts @@ -49,6 +49,7 @@ import { STATIC_MAINNET_TOKEN_LIST, TokenDetectionController, controllerName, + mapChainIdWithTokenListMap, } from './TokenDetectionController'; import { getDefaultTokenListState, @@ -329,6 +330,7 @@ describe('TokenDetectionController', () => { { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, + useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: defaultSelectedAccount, @@ -363,13 +365,17 @@ describe('TokenDetectionController', () => { { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, + useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getAccount: selectedAccount, getSelectedAccount: selectedAccount, }, }, + + // Salim .... async ({ controller, mockTokenListGetState, callActionSpy }) => { + mockMultiChainAccountsService(); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -404,6 +410,80 @@ describe('TokenDetectionController', () => { ); }); + it('should not call add tokens if balance is not available on account api', async () => { + const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ + [sampleTokenA.address]: new BN(1), + }); + + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); + await withController( + { + options: { + getBalancesInSingleCall: mockGetBalancesInSingleCall, + useAccountsAPI: true, // USING ACCOUNTS API + }, + mocks: { + getAccount: selectedAccount, + getSelectedAccount: selectedAccount, + }, + }, + + async ({ controller, mockTokenListGetState, callActionSpy }) => { + mockMultiChainAccountsService(); + + const mockAPI = mockMultiChainAccountsService(); + mockAPI.mockFetchMultiChainBalances.mockResolvedValue({ + count: 0, + balances: [ + { + object: 'token', + address: '0xaddress', + name: 'Mock Token', + symbol: 'MOCK', + decimals: 18, + balance: '10.18', + chainId: 2, + }, + ], + unprocessedNetworks: [], + }); + + mockTokenListGetState({ + ...getDefaultTokenListState(), + tokensChainsCache: { + '0x1': { + timestamp: 0, + data: { + test: { + name: sampleTokenA.name, + symbol: sampleTokenA.symbol, + decimals: sampleTokenA.decimals, + address: 'test', + occurrences: 1, + aggregators: sampleTokenA.aggregators, + iconUrl: sampleTokenA.image, + }, + }, + }, + }, + }); + + await controller.start(); + + expect(callActionSpy).not.toHaveBeenCalledWith( + 'TokensController:addDetectedTokens', + [sampleTokenA], + { + chainId: ChainId.mainnet, + selectedAddress: selectedAccount.address, + }, + ); + }, + ); + }); + it('should detect tokens correctly on the Polygon network', async () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), @@ -415,6 +495,7 @@ describe('TokenDetectionController', () => { { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, + useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getAccount: selectedAccount, @@ -428,6 +509,7 @@ describe('TokenDetectionController', () => { mockGetNetworkClientById, callActionSpy, }) => { + mockMultiChainAccountsService(); mockNetworkState({ ...getDefaultNetworkControllerState(), selectedNetworkClientId: 'polygon', @@ -494,6 +576,7 @@ describe('TokenDetectionController', () => { }, }, async ({ controller, mockTokenListGetState, callActionSpy }) => { + mockMultiChainAccountsService(); const tokenListState = { ...getDefaultTokenListState(), tokensChainsCache: { @@ -551,6 +634,7 @@ describe('TokenDetectionController', () => { { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, + useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getAccount: selectedAccount, @@ -563,6 +647,7 @@ describe('TokenDetectionController', () => { mockTokenListGetState, callActionSpy, }) => { + mockMultiChainAccountsService(); mockTokensGetState({ ...getDefaultTokensState(), ignoredTokens: [sampleTokenA.address], @@ -604,12 +689,14 @@ describe('TokenDetectionController', () => { { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, + useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: defaultSelectedAccount, }, }, async ({ controller, mockTokenListGetState, callActionSpy }) => { + mockMultiChainAccountsService(); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -666,6 +753,7 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, + useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: firstSelectedAccount, @@ -677,6 +765,7 @@ describe('TokenDetectionController', () => { triggerSelectedAccountChange, callActionSpy, }) => { + mockMultiChainAccountsService(); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -725,6 +814,7 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, + useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: selectedAccount, @@ -735,6 +825,7 @@ describe('TokenDetectionController', () => { triggerSelectedAccountChange, callActionSpy, }) => { + mockMultiChainAccountsService(); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -844,6 +935,7 @@ describe('TokenDetectionController', () => { options: { disabled: true, getBalancesInSingleCall: mockGetBalancesInSingleCall, + useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: firstSelectedAccount, @@ -914,6 +1006,7 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, + useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: firstSelectedAccount, @@ -926,6 +1019,7 @@ describe('TokenDetectionController', () => { triggerSelectedAccountChange, callActionSpy, }) => { + mockMultiChainAccountsService(); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -978,6 +1072,7 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, + useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: selectedAccount, @@ -989,6 +1084,7 @@ describe('TokenDetectionController', () => { triggerPreferencesStateChange, callActionSpy, }) => { + mockMultiChainAccountsService(); mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), @@ -1049,6 +1145,7 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, + useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: firstSelectedAccount, @@ -1061,6 +1158,7 @@ describe('TokenDetectionController', () => { triggerPreferencesStateChange, callActionSpy, }) => { + mockMultiChainAccountsService(); mockGetAccount(firstSelectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), @@ -1422,67 +1520,6 @@ describe('TokenDetectionController', () => { }); describe('when "disabled" is false', () => { - it('should detect new tokens after switching network client id', async () => { - const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ - [sampleTokenA.address]: new BN(1), - }); - const selectedAccount = createMockInternalAccount({ - address: '0x0000000000000000000000000000000000000001', - }); - await withController( - { - options: { - disabled: false, - getBalancesInSingleCall: mockGetBalancesInSingleCall, - }, - mocks: { - getAccount: selectedAccount, - getSelectedAccount: selectedAccount, - }, - }, - async ({ - mockTokenListGetState, - callActionSpy, - triggerNetworkDidChange, - }) => { - mockTokenListGetState({ - ...getDefaultTokenListState(), - tokensChainsCache: { - '0x89': { - timestamp: 0, - data: { - [sampleTokenA.address]: { - name: sampleTokenA.name, - symbol: sampleTokenA.symbol, - decimals: sampleTokenA.decimals, - address: sampleTokenA.address, - occurrences: 1, - aggregators: sampleTokenA.aggregators, - iconUrl: sampleTokenA.image, - }, - }, - }, - }, - }); - - triggerNetworkDidChange({ - ...getDefaultNetworkControllerState(), - selectedNetworkClientId: 'polygon', - }); - await advanceTime({ clock, duration: 1 }); - - expect(callActionSpy).toHaveBeenCalledWith( - 'TokensController:addDetectedTokens', - [sampleTokenA], - { - chainId: '0x89', - selectedAddress: selectedAccount.address, - }, - ); - }, - ); - }); - it('should not detect new tokens after switching to a chain that does not support token detection', async () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), @@ -1737,6 +1774,7 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, + useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: selectedAccount, @@ -1748,6 +1786,7 @@ describe('TokenDetectionController', () => { callActionSpy, triggerTokenListStateChange, }) => { + mockMultiChainAccountsService(); const tokenList = { [sampleTokenA.address]: { name: sampleTokenA.name, @@ -1951,6 +1990,7 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, + useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: selectedAccount, @@ -1962,6 +2002,7 @@ describe('TokenDetectionController', () => { triggerTokenListStateChange, controller, }) => { + mockMultiChainAccountsService(); const tokenListState = { ...getDefaultTokenListState(), tokensChainsCache: { @@ -2010,6 +2051,7 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, + useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: selectedAccount, @@ -2021,6 +2063,7 @@ describe('TokenDetectionController', () => { triggerTokenListStateChange, controller, }) => { + mockMultiChainAccountsService(); const tokenListState = { ...getDefaultTokenListState(), tokensChainsCache: { @@ -2087,6 +2130,7 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, + useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: selectedAccount, @@ -2098,6 +2142,7 @@ describe('TokenDetectionController', () => { triggerTokenListStateChange, controller, }) => { + mockMultiChainAccountsService(); const tokenListState = { ...getDefaultTokenListState(), tokensChainsCache: { @@ -2208,33 +2253,33 @@ describe('TokenDetectionController', () => { }); controller.startPolling({ - networkClientId: 'mainnet', + chainIds: ['0x1'], address: '0x1', }); controller.startPolling({ - networkClientId: 'sepolia', + chainIds: ['0xaa36a7'], address: '0xdeadbeef', }); controller.startPolling({ - networkClientId: 'goerli', + chainIds: ['0x5'], address: '0x3', }); await advanceTime({ clock, duration: 0 }); expect(spy.mock.calls).toMatchObject([ - [{ networkClientId: 'mainnet', selectedAddress: '0x1' }], - [{ networkClientId: 'sepolia', selectedAddress: '0xdeadbeef' }], - [{ networkClientId: 'goerli', selectedAddress: '0x3' }], + [{ chainIds: ['0x1'], selectedAddress: '0x1' }], + [{ chainIds: ['0xaa36a7'], selectedAddress: '0xdeadbeef' }], + [{ chainIds: ['0x5'], selectedAddress: '0x3' }], ]); await advanceTime({ clock, duration: DEFAULT_INTERVAL }); expect(spy.mock.calls).toMatchObject([ - [{ networkClientId: 'mainnet', selectedAddress: '0x1' }], - [{ networkClientId: 'sepolia', selectedAddress: '0xdeadbeef' }], - [{ networkClientId: 'goerli', selectedAddress: '0x3' }], - [{ networkClientId: 'mainnet', selectedAddress: '0x1' }], - [{ networkClientId: 'sepolia', selectedAddress: '0xdeadbeef' }], - [{ networkClientId: 'goerli', selectedAddress: '0x3' }], + [{ chainIds: ['0x1'], selectedAddress: '0x1' }], + [{ chainIds: ['0xaa36a7'], selectedAddress: '0xdeadbeef' }], + [{ chainIds: ['0x5'], selectedAddress: '0x3' }], + [{ chainIds: ['0x1'], selectedAddress: '0x1' }], + [{ chainIds: ['0xaa36a7'], selectedAddress: '0xdeadbeef' }], + [{ chainIds: ['0x5'], selectedAddress: '0x3' }], ]); }, ); @@ -2254,6 +2299,7 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, + useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: selectedAccount, @@ -2266,6 +2312,7 @@ describe('TokenDetectionController', () => { triggerPreferencesStateChange, callActionSpy, }) => { + mockMultiChainAccountsService(); mockNetworkState({ ...getDefaultNetworkControllerState(), selectedNetworkClientId: NetworkType.goerli, @@ -2275,7 +2322,7 @@ describe('TokenDetectionController', () => { useTokenDetection: false, }); await controller.detectTokens({ - networkClientId: NetworkType.goerli, + chainIds: ['0x5'], selectedAddress: selectedAccount.address, }); expect(callActionSpy).not.toHaveBeenCalledWith( @@ -2303,6 +2350,7 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, + // useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: selectedAccount, @@ -2314,12 +2362,13 @@ describe('TokenDetectionController', () => { triggerPreferencesStateChange, callActionSpy, }) => { + mockMultiChainAccountsService(); triggerPreferencesStateChange({ ...getDefaultPreferencesState(), useTokenDetection: false, }); await controller.detectTokens({ - networkClientId: NetworkType.mainnet, + chainIds: ['0x1'], selectedAddress: selectedAccount.address, }); expect(callActionSpy).toHaveBeenLastCalledWith( @@ -2353,6 +2402,7 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, + useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: selectedAccount, @@ -2360,6 +2410,7 @@ describe('TokenDetectionController', () => { }, }, async ({ controller, mockTokenListGetState, callActionSpy }) => { + mockMultiChainAccountsService(); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -2381,7 +2432,7 @@ describe('TokenDetectionController', () => { }); await controller.detectTokens({ - networkClientId: NetworkType.mainnet, + chainIds: ['0x1'], selectedAddress: selectedAccount.address, }); @@ -2412,6 +2463,7 @@ describe('TokenDetectionController', () => { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, trackMetaMetricsEvent: mockTrackMetaMetricsEvent, + useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: selectedAccount, @@ -2419,6 +2471,7 @@ describe('TokenDetectionController', () => { }, }, async ({ controller, mockTokenListGetState }) => { + mockMultiChainAccountsService(); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { @@ -2440,7 +2493,7 @@ describe('TokenDetectionController', () => { }); await controller.detectTokens({ - networkClientId: NetworkType.mainnet, + chainIds: ['0x1'], selectedAddress: selectedAccount.address, }); @@ -2474,6 +2527,7 @@ describe('TokenDetectionController', () => { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, trackMetaMetricsEvent: mockTrackMetaMetricsEvent, + useAccountsAPI: true, // USING ACCOUNTS API }, }, async ({ @@ -2482,6 +2536,7 @@ describe('TokenDetectionController', () => { mockTokenListGetState, callActionSpy, }) => { + mockMultiChainAccountsService(); // @ts-expect-error forcing an undefined value mockGetAccount(undefined); mockTokenListGetState({ @@ -2505,7 +2560,7 @@ describe('TokenDetectionController', () => { }); await controller.detectTokens({ - networkClientId: NetworkType.mainnet, + chainIds: ['0x1'], }); expect(callActionSpy).toHaveBeenLastCalledWith( @@ -2540,6 +2595,54 @@ describe('TokenDetectionController', () => { ); }); + it('should fallback to rpc call', async () => { + const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ + [sampleTokenA.address]: new BN(1), + }); + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); + await withController( + { + options: { + disabled: false, + getBalancesInSingleCall: mockGetBalancesInSingleCall, + useAccountsAPI: true, // USING ACCOUNTS API + }, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, + }, + }, + async ({ + controller, + mockNetworkState, + triggerPreferencesStateChange, + callActionSpy, + }) => { + const mockAPI = mockMultiChainAccountsService(); + mockAPI.mockFetchMultiChainBalances.mockRejectedValue( + new Error('Mock Error'), + ); + mockNetworkState({ + ...getDefaultNetworkControllerState(), + selectedNetworkClientId: 'polygon', + }); + triggerPreferencesStateChange({ + ...getDefaultPreferencesState(), + useTokenDetection: false, + }); + await controller.detectTokens({ + chainIds: ['0x5'], + selectedAddress: selectedAccount.address, + }); + expect(callActionSpy).not.toHaveBeenCalledWith( + 'TokensController:addDetectedTokens', + ); + }, + ); + }); + /** * Test Utility - Arrange and Act `detectTokens()` with the Accounts API feature * RPC flow will return `sampleTokenA` and the Accounts API flow will use `sampleTokenB` @@ -2634,7 +2737,7 @@ describe('TokenDetectionController', () => { // Act await controller.detectTokens({ - networkClientId: NetworkType.mainnet, + chainIds: ['0x1'], selectedAddress: selectedAccount.address, }); @@ -2752,6 +2855,57 @@ describe('TokenDetectionController', () => { assertTokensNeverAdded(); }); }); + + describe('mapChainIdWithTokenListMap', () => { + it('should return an empty object when given an empty input', () => { + const tokensChainsCache = {}; + const result = mapChainIdWithTokenListMap(tokensChainsCache); + expect(result).toStrictEqual({}); + }); + + it('should return the same structure when there is no "data" property in the object', () => { + const tokensChainsCache = { + chain1: { info: 'no data property' }, + }; + const result = mapChainIdWithTokenListMap(tokensChainsCache); + expect(result).toStrictEqual(tokensChainsCache); // Expect unchanged structure + }); + + it('should map "data" property if present in the object', () => { + const tokensChainsCache = { + chain1: { data: 'someData' }, + }; + const result = mapChainIdWithTokenListMap(tokensChainsCache); + expect(result).toStrictEqual({ chain1: 'someData' }); + }); + + it('should handle multiple chains with mixed "data" properties', () => { + const tokensChainsCache = { + chain1: { data: 'someData1' }, + chain2: { info: 'no data property' }, + chain3: { data: 'someData3' }, + }; + const result = mapChainIdWithTokenListMap(tokensChainsCache); + + expect(result).toStrictEqual({ + chain1: 'someData1', + chain2: { info: 'no data property' }, + chain3: 'someData3', + }); + }); + + it('should handle nested object with "data" property correctly', () => { + const tokensChainsCache = { + chain1: { + data: { + nested: 'nestedData', + }, + }, + }; + const result = mapChainIdWithTokenListMap(tokensChainsCache); + expect(result).toStrictEqual({ chain1: { nested: 'nestedData' } }); + }); + }); }); /** diff --git a/packages/assets-controllers/src/TokenDetectionController.ts b/packages/assets-controllers/src/TokenDetectionController.ts index 52a84bfc01..a2d9a744c1 100644 --- a/packages/assets-controllers/src/TokenDetectionController.ts +++ b/packages/assets-controllers/src/TokenDetectionController.ts @@ -70,6 +70,11 @@ type TokenDetectionMap = { [P in keyof TokenListMap]: Omit; }; +type NetworkClient = { + chainId: Hex; + networkClientId: string; +}; + export const STATIC_MAINNET_TOKEN_LIST = Object.entries( contractMap, ).reduce((acc, [base, contract]) => { @@ -90,7 +95,9 @@ export const STATIC_MAINNET_TOKEN_LIST = Object.entries( * @param tokensChainsCache - TokensChainsCache input object * @returns returns the map of chainId with TokenListMap */ -function mapChainIdWithTokenListMap(tokensChainsCache: TokensChainsCache) { +export function mapChainIdWithTokenListMap( + tokensChainsCache: TokensChainsCache, +) { return mapValues(tokensChainsCache, (value) => { if (isObject(value) && 'data' in value) { return get(value, ['data']); @@ -147,7 +154,7 @@ export type TokenDetectionControllerMessenger = RestrictedControllerMessenger< /** The input to start polling for the {@link TokenDetectionController} */ type TokenDetectionPollingInput = { - networkClientId: NetworkClientId; + chainIds: Hex[]; address: string; }; @@ -219,25 +226,27 @@ export class TokenDetectionController extends StaticIntervalPollingController hexToNumber(chainId)); - if (!supportedNetworks || !supportedNetworks.includes(chainIdNumber)) { + if ( + !supportedNetworks || + !chainIdNumbers.every((id) => supportedNetworks.includes(id)) + ) { const supportedNetworksErrStr = (supportedNetworks ?? []).toString(); throw new Error( - `Unsupported Network: supported networks ${supportedNetworksErrStr}, network: ${chainIdNumber}`, + `Unsupported Network: supported networks ${supportedNetworksErrStr}, requested networks: ${chainIdNumbers.toString()}`, ); } const result = await fetchMultiChainBalances( address, { - networks: [chainIdNumber], + networks: chainIdNumbers, }, this.platform, ); @@ -306,6 +315,7 @@ export class TokenDetectionController extends StaticIntervalPollingController { - const isNetworkClientIdChanged = - this.#networkClientId !== selectedNetworkClientId; - - const { chainId: newChainId } = - this.#getCorrectChainIdAndNetworkClientId(selectedNetworkClientId); - this.#isDetectionEnabledForNetwork = - isTokenDetectionSupportedForNetwork(newChainId); - - if (isNetworkClientIdChanged && this.#isDetectionEnabledForNetwork) { - this.#networkClientId = selectedNetworkClientId; - await this.#restartTokenDetection({ - networkClientId: this.#networkClientId, - }); - } - }, - ); } /** @@ -501,22 +489,38 @@ export class TokenDetectionController extends StaticIntervalPollingController { + const configuration = networkConfigurationsByChainId[chainId]; + return { + chainId, + networkClientId: + configuration.rpcEndpoints[configuration.defaultRpcEndpointIndex] + .networkClientId, + }; + }); + } + + #getCorrectChainIdAndNetworkClientId() { const { selectedNetworkClientId } = this.messagingSystem.call( 'NetworkController:getState', ); @@ -533,14 +537,14 @@ export class TokenDetectionController extends StaticIntervalPollingController { if (!this.isActive) { return; } await this.detectTokens({ - networkClientId, + chainIds, selectedAddress: address, }); } @@ -551,93 +555,178 @@ export class TokenDetectionController extends StaticIntervalPollingController { await this.detectTokens({ - networkClientId, + chainIds, selectedAddress, }); this.setIntervalLength(DEFAULT_INTERVAL); } + #getChainsToDetect( + clientNetworks: NetworkClient[], + supportedNetworks: number[] | null | undefined, + ) { + const chainsToDetectUsingAccountAPI: Hex[] = []; + const chainsToDetectUsingRpc: NetworkClient[] = []; + + clientNetworks.forEach(({ chainId, networkClientId }) => { + if (supportedNetworks?.includes(hexToNumber(chainId))) { + chainsToDetectUsingAccountAPI.push(chainId); + } else { + chainsToDetectUsingRpc.push({ chainId, networkClientId }); + } + }); + + return { chainsToDetectUsingRpc, chainsToDetectUsingAccountAPI }; + } + + async #attemptAccountAPIDetection( + chainsToDetectUsingAccountAPI: Hex[], + addressToDetect: string, + supportedNetworks: number[] | null, + ) { + return await this.#addDetectedTokensViaAPI({ + chainIds: chainsToDetectUsingAccountAPI, + selectedAddress: addressToDetect, + supportedNetworks, + }); + } + + #addChainsToRpcDetection( + chainsToDetectUsingRpc: NetworkClient[], + chainsToDetectUsingAccountAPI: Hex[], + clientNetworks: NetworkClient[], + ): void { + chainsToDetectUsingAccountAPI.forEach((chainId) => { + const networkEntry = clientNetworks.find( + (network) => network.chainId === chainId, + ); + if (networkEntry) { + chainsToDetectUsingRpc.push({ + chainId: networkEntry.chainId, + networkClientId: networkEntry.networkClientId, + }); + } + }); + } + + #shouldDetectTokens(chainId: Hex): boolean { + if (!isTokenDetectionSupportedForNetwork(chainId)) { + return false; + } + if ( + !this.#isDetectionEnabledFromPreferences && + chainId !== ChainId.mainnet + ) { + return false; + } + + const isMainnetDetectionInactive = + !this.#isDetectionEnabledFromPreferences && chainId === ChainId.mainnet; + if (isMainnetDetectionInactive) { + this.#tokensChainsCache = this.#getConvertedStaticMainnetTokenList(); + } else { + const { tokensChainsCache } = this.messagingSystem.call( + 'TokenListController:getState', + ); + this.#tokensChainsCache = tokensChainsCache ?? {}; + } + + return true; + } + + async #detectTokensUsingRpc( + chainsToDetectUsingRpc: NetworkClient[], + addressToDetect: string, + ): Promise { + for (const { chainId, networkClientId } of chainsToDetectUsingRpc) { + if (!this.#shouldDetectTokens(chainId)) { + continue; + } + + const tokenCandidateSlices = this.#getSlicesOfTokensToDetect({ + chainId, + selectedAddress: addressToDetect, + }); + const tokenDetectionPromises = tokenCandidateSlices.map((tokensSlice) => + this.#addDetectedTokens({ + tokensSlice, + selectedAddress: addressToDetect, + networkClientId, + chainId, + }), + ); + + await Promise.all(tokenDetectionPromises); + } + } + /** * For each token in the token list provided by the TokenListController, checks the token's balance for the selected account address on the active network. * On mainnet, if token detection is disabled in preferences, ERC20 token auto detection will be triggered for each contract address in the legacy token list from the @metamask/contract-metadata repo. * * @param options - Options for token detection. - * @param options.networkClientId - The ID of the network client to use. + * @param options.chainIds - The chain IDs of the network client to use. * @param options.selectedAddress - the selectedAddress against which to detect for token balances. */ async detectTokens({ - networkClientId, + chainIds, selectedAddress, }: { - networkClientId?: NetworkClientId; + chainIds?: Hex[]; selectedAddress?: string; } = {}): Promise { if (!this.isActive) { return; } - const addressAgainstWhichToDetect = - selectedAddress ?? this.#getSelectedAddress(); - const { chainId, networkClientId: selectedNetworkClientId } = - this.#getCorrectChainIdAndNetworkClientId(networkClientId); - const chainIdAgainstWhichToDetect = chainId; - const networkClientIdAgainstWhichToDetect = selectedNetworkClientId; + const addressToDetect = selectedAddress ?? this.#getSelectedAddress(); + const clientNetworks = this.#getCorrectNetworkClientIdByChainId(chainIds); - if (!isTokenDetectionSupportedForNetwork(chainIdAgainstWhichToDetect)) { - return; + let supportedNetworks; + if (this.#accountsAPI.isAccountsAPIEnabled) { + supportedNetworks = await this.#accountsAPI.getSupportedNetworks(); } - if ( - !this.#isDetectionEnabledFromPreferences && - chainIdAgainstWhichToDetect !== ChainId.mainnet - ) { - return; - } - const isTokenDetectionInactiveInMainnet = - !this.#isDetectionEnabledFromPreferences && - chainIdAgainstWhichToDetect === ChainId.mainnet; - const { tokensChainsCache } = this.messagingSystem.call( - 'TokenListController:getState', - ); - this.#tokensChainsCache = isTokenDetectionInactiveInMainnet - ? this.#getConvertedStaticMainnetTokenList() - : tokensChainsCache ?? {}; + const { chainsToDetectUsingRpc, chainsToDetectUsingAccountAPI } = + this.#getChainsToDetect(clientNetworks, supportedNetworks); + + // Try detecting tokens via Account API first if conditions allow + if (supportedNetworks && chainsToDetectUsingAccountAPI.length > 0) { + const apiResult = await this.#attemptAccountAPIDetection( + chainsToDetectUsingAccountAPI, + addressToDetect, + supportedNetworks, + ); - const tokenCandidateSlices = this.#getSlicesOfTokensToDetect({ - chainId: chainIdAgainstWhichToDetect, - selectedAddress: addressAgainstWhichToDetect, - }); + // If API succeeds and no chains are left for RPC detection, we can return early + if ( + apiResult?.result === 'success' && + chainsToDetectUsingRpc.length === 0 + ) { + return; + } - // Attempt Accounts API Detection - const accountAPIResult = await this.#addDetectedTokensViaAPI({ - chainId: chainIdAgainstWhichToDetect, - selectedAddress: addressAgainstWhichToDetect, - tokenCandidateSlices, - }); - if (accountAPIResult?.result === 'success') { - return; + // If API fails or chainsToDetectUsingRpc still has items, add chains to RPC detection + this.#addChainsToRpcDetection( + chainsToDetectUsingRpc, + chainsToDetectUsingAccountAPI, + clientNetworks, + ); } - // Attempt RPC Detection - const tokenDetectionPromises = tokenCandidateSlices.map((tokensSlice) => - this.#addDetectedTokens({ - tokensSlice, - selectedAddress: addressAgainstWhichToDetect, - networkClientId: networkClientIdAgainstWhichToDetect, - chainId: chainIdAgainstWhichToDetect, - }), - ); - - await Promise.all(tokenDetectionPromises); + // Proceed with RPC detection if there are chains remaining in chainsToDetectUsingRpc + if (chainsToDetectUsingRpc.length > 0) { + await this.#detectTokensUsingRpc(chainsToDetectUsingRpc, addressToDetect); + } } #getSlicesOfTokensToDetect({ @@ -714,91 +803,160 @@ export class TokenDetectionController extends StaticIntervalPollingController { - const tokenBalances = await this.#accountsAPI - .getMultiChainBalances(selectedAddress, chainId) + // Fetch balances for multiple chain IDs at once + const tokenBalancesByChain = await this.#accountsAPI + .getMultiNetworksBalances(selectedAddress, chainIds, supportedNetworks) .catch(() => null); - if (!tokenBalances || tokenBalances.length === 0) { + if ( + !tokenBalancesByChain || + Object.keys(tokenBalancesByChain).length === 0 + ) { return { result: 'failed' } as const; } - const tokensWithBalance: Token[] = []; - const eventTokensDetails: string[] = []; - - const tokenCandidateSet = new Set(tokenCandidateSlices.flat()); + // Process each chain ID individually + for (const chainId of chainIds) { + const isTokenDetectionInactiveInMainnet = + !this.#isDetectionEnabledFromPreferences && + chainId === ChainId.mainnet; + const { tokensChainsCache } = this.messagingSystem.call( + 'TokenListController:getState', + ); + this.#tokensChainsCache = isTokenDetectionInactiveInMainnet + ? this.#getConvertedStaticMainnetTokenList() + : tokensChainsCache ?? {}; + + // Generate token candidates based on chainId and selectedAddress + const tokenCandidateSlices = this.#getSlicesOfTokensToDetect({ + chainId, + selectedAddress, + }); - tokenBalances.forEach((token) => { - const tokenAddress = token.address; + // Filter balances for the current chainId + const tokenBalances = tokenBalancesByChain.filter( + (balance) => balance.chainId === hexToNumber(chainId), + ); - // Make sure that the token to add is in our candidate list - // Ensures we don't add tokens we already own - if (!tokenCandidateSet.has(token.address)) { - return; + if (!tokenBalances || tokenBalances.length === 0) { + continue; } - // We need specific data from tokensChainsCache to correctly create a token - // So even if we have a token that was detected correctly by the API, if its missing data we cannot safely add it. - if (!this.#tokensChainsCache[chainId].data[token.address]) { - return; + // Use helper function to filter tokens with balance for this chainId + const { tokensWithBalance, eventTokensDetails } = + this.#filterAndBuildTokensWithBalance( + tokenCandidateSlices, + tokenBalances, + chainId, + ); + + if (tokensWithBalance.length) { + this.#trackMetaMetricsEvent({ + event: 'Token Detected', + category: 'Wallet', + properties: { + tokens: eventTokensDetails, + // TODO: Either fix this lint violation or explain why it's necessary to ignore. + // eslint-disable-next-line @typescript-eslint/naming-convention + token_standard: ERC20, + // TODO: Either fix this lint violation or explain why it's necessary to ignore. + // eslint-disable-next-line @typescript-eslint/naming-convention + asset_type: ASSET_TYPES.TOKEN, + }, + }); + + await this.messagingSystem.call( + 'TokensController:addDetectedTokens', + tokensWithBalance, + { + selectedAddress, + chainId, + }, + ); } + } - const { decimals, symbol, aggregators, iconUrl, name } = - this.#tokensChainsCache[chainId].data[token.address]; - eventTokensDetails.push(`${symbol} - ${tokenAddress}`); - tokensWithBalance.push({ - address: tokenAddress, - decimals, - symbol, - aggregators, - image: iconUrl, - isERC721: false, - name, - }); - }); + return { result: 'success' } as const; + }); + } - if (tokensWithBalance.length) { - this.#trackMetaMetricsEvent({ - event: 'Token Detected', - category: 'Wallet', - properties: { - tokens: eventTokensDetails, - // TODO: Either fix this lint violation or explain why it's necessary to ignore. - // eslint-disable-next-line @typescript-eslint/naming-convention - token_standard: ERC20, - // TODO: Either fix this lint violation or explain why it's necessary to ignore. - // eslint-disable-next-line @typescript-eslint/naming-convention - asset_type: ASSET_TYPES.TOKEN, - }, - }); + /** + * Helper function to filter and build token data for detected tokens + * @param options.tokenCandidateSlices - these are tokens we know a user does not have (by checking the tokens controller). + * We will use these these token candidates to determine if a token found from the API is valid to be added on the users wallet. + * It will also prevent us to adding tokens a user already has + * @param tokenBalances - Tokens balances fetched from API + * @param chainId - The chain ID being processed + * @returns an object containing tokensWithBalance and eventTokensDetails arrays + */ - await this.messagingSystem.call( - 'TokensController:addDetectedTokens', - tokensWithBalance, - { - selectedAddress, - chainId, - }, - ); + #filterAndBuildTokensWithBalance( + tokenCandidateSlices: string[][], + tokenBalances: + | { + object: string; + type?: string; + timestamp?: string; + address: string; + symbol: string; + name: string; + decimals: number; + chainId: number; + balance: string; + }[] + | null, + chainId: Hex, + ) { + const tokensWithBalance: Token[] = []; + const eventTokensDetails: string[] = []; + + const tokenCandidateSet = new Set(tokenCandidateSlices.flat()); + + tokenBalances?.forEach((token) => { + const tokenAddress = token.address; + + // Make sure the token to add is in our candidate list + if (!tokenCandidateSet.has(tokenAddress)) { + return; } - return { result: 'success' } as const; + // Retrieve token data from cache to safely add it + const tokenData = this.#tokensChainsCache[chainId]?.data[tokenAddress]; + + // We need specific data from tokensChainsCache to correctly create a token + // So even if we have a token that was detected correctly by the API, if its missing data we cannot safely add it. + if (!tokenData) { + return; + } + + const { decimals, symbol, aggregators, iconUrl, name } = tokenData; + eventTokensDetails.push(`${symbol} - ${tokenAddress}`); + tokensWithBalance.push({ + address: tokenAddress, + decimals, + symbol, + aggregators, + image: iconUrl, + isERC721: false, + name, + }); }); + + return { tokensWithBalance, eventTokensDetails }; } async #addDetectedTokens({ diff --git a/packages/assets-controllers/src/TokensController.test.ts b/packages/assets-controllers/src/TokensController.test.ts index eb12ef587d..22f7c60a1a 100644 --- a/packages/assets-controllers/src/TokensController.test.ts +++ b/packages/assets-controllers/src/TokensController.test.ts @@ -198,6 +198,84 @@ describe('TokensController', () => { }); }); + it('should add tokens and update existing ones and detected tokens', async () => { + const selectedAddress = '0x0001'; + const selectedAccount = createMockInternalAccount({ + address: selectedAddress, + }); + await withController( + { + mockNetworkClientConfigurationsByNetworkClientId: { + networkClientId1: buildCustomNetworkClientConfiguration({ + chainId: '0x1', + }), + }, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, + }, + }, + async ({ controller }) => { + await controller.addDetectedTokens( + [ + { + address: '0x01', + symbol: 'barA', + decimals: 2, + }, + ], + { + selectedAddress: '0x0001', + chainId: '0x1', + }, + ); + + await controller.addTokens( + [ + { + address: '0x01', + symbol: 'barA', + decimals: 2, + aggregators: [], + name: 'Token1', + }, + { + address: '0x02', + symbol: 'barB', + decimals: 2, + aggregators: [], + name: 'Token2', + }, + ], + 'networkClientId1', + ); + + expect(controller.state.allTokens).toStrictEqual({ + '0x1': { + '0x0001': [ + { + address: '0x01', + symbol: 'barA', + decimals: 2, + aggregators: [], + name: 'Token1', + image: undefined, + }, + { + address: '0x02', + symbol: 'barB', + decimals: 2, + aggregators: [], + name: 'Token2', + image: undefined, + }, + ], + }, + }); + }, + ); + }); + it('should add detected tokens', async () => { await withController(async ({ controller }) => { await controller.addDetectedTokens([ @@ -2142,6 +2220,66 @@ describe('TokensController', () => { }, ); }); + + it('should clear allDetectedTokens under chain ID and selected address when a detected token is added to tokens list', async () => { + const selectedAddress = '0x1'; + const selectedAccount = createMockInternalAccount({ + address: selectedAddress, + }); + const tokenAddress = '0x01'; + const dummyDetectedTokens = [ + { + address: tokenAddress, + symbol: 'barA', + decimals: 2, + aggregators: [], + isERC721: undefined, + name: undefined, + image: undefined, + }, + ]; + const dummyTokens = [ + { + address: tokenAddress, + symbol: 'barA', + decimals: 2, + aggregators: [], + isERC721: undefined, + name: undefined, + image: undefined, + }, + ]; + + await withController( + { + options: { + chainId: ChainId.mainnet, + }, + mocks: { + getSelectedAccount: selectedAccount, + }, + }, + async ({ controller }) => { + // First, add detected tokens + await controller.addDetectedTokens(dummyDetectedTokens); + expect( + controller.state.allDetectedTokens[ChainId.mainnet][ + selectedAddress + ], + ).toStrictEqual(dummyDetectedTokens); + + // Now, add the same token to the tokens list + await controller.addTokens(dummyTokens); + + // Check that allDetectedTokens for the selected address is cleared + expect( + controller.state.allDetectedTokens[ChainId.mainnet][ + selectedAddress + ], + ).toStrictEqual([]); + }, + ); + }); }); describe('when TokenListController:stateChange is published', () => { diff --git a/packages/assets-controllers/src/TokensController.ts b/packages/assets-controllers/src/TokensController.ts index fe680cdd0d..fbf135799a 100644 --- a/packages/assets-controllers/src/TokensController.ts +++ b/packages/assets-controllers/src/TokensController.ts @@ -462,13 +462,16 @@ export class TokensController extends BaseController< */ async addTokens(tokensToImport: Token[], networkClientId?: NetworkClientId) { const releaseLock = await this.#mutex.acquire(); - const { tokens, detectedTokens, ignoredTokens } = this.state; + const { ignoredTokens, allDetectedTokens } = this.state; const importedTokensMap: { [key: string]: true } = {}; // Used later to dedupe imported tokens - const newTokensMap = tokens.reduce((output, current) => { - output[current.address] = current; - return output; - }, {} as { [address: string]: Token }); + const newTokensMap = Object.values(tokensToImport).reduce( + (output, token) => { + output[token.address] = token; + return output; + }, + {} as { [address: string]: Token }, + ); try { tokensToImport.forEach((tokenToAdd) => { const { address, symbol, decimals, image, aggregators, name } = @@ -488,9 +491,6 @@ export class TokensController extends BaseController< }); const newTokens = Object.values(newTokensMap); - const newDetectedTokens = detectedTokens.filter( - (token) => !importedTokensMap[token.address.toLowerCase()], - ); const newIgnoredTokens = ignoredTokens.filter( (tokenAddress) => !newTokensMap[tokenAddress.toLowerCase()], ); @@ -503,6 +503,16 @@ export class TokensController extends BaseController< ).configuration.chainId; } + const newDetectedTokens = + interactingChainId && + allDetectedTokens[interactingChainId]?.[this.#getSelectedAddress()] + ? allDetectedTokens[interactingChainId]?.[ + this.#getSelectedAddress() + ].filter( + (token: Token) => !importedTokensMap[token.address.toLowerCase()], + ) + : []; + const { newAllTokens, newAllDetectedTokens, newAllIgnoredTokens } = this.#getNewAllTokensState({ newTokens, diff --git a/packages/assets-controllers/src/multi-chain-accounts-service/mocks/mock-get-balances.ts b/packages/assets-controllers/src/multi-chain-accounts-service/mocks/mock-get-balances.ts index 08b0b98a44..3ac5f71697 100644 --- a/packages/assets-controllers/src/multi-chain-accounts-service/mocks/mock-get-balances.ts +++ b/packages/assets-controllers/src/multi-chain-accounts-service/mocks/mock-get-balances.ts @@ -61,6 +61,24 @@ export const MOCK_GET_BALANCES_RESPONSE: GetBalancesResponse = { balance: '100.000000000000000000', chainId: 59144, }, + { + object: 'token', + address: '0x514910771AF9Ca656af840dff83E8264EcF986CA', + name: 'Chainlink', + symbol: 'LINK', + decimals: 18, + balance: '10', + chainId: 1, + }, + { + object: 'token', + address: '0x514910771AF9Ca656af840dff83E8264EcF986CA', + name: 'Chainlink', + symbol: 'LINK', + decimals: 18, + balance: '10', + chainId: 137, + }, ], unprocessedNetworks: [], };