From 154172d5f852c3899e8cfe18e9f3349324aa231c Mon Sep 17 00:00:00 2001 From: Jyoti Puri Date: Fri, 2 Dec 2022 23:29:03 +0530 Subject: [PATCH] Network request in background should not start until onboarding is completed (#16773) --- .../controllers/incoming-transactions.js | 50 +++++++--------- .../controllers/incoming-transactions.test.js | 36 ++++++++++++ app/scripts/lib/account-tracker.js | 16 ++++++ app/scripts/lib/util.js | 28 +++++++++ app/scripts/metamask-controller.js | 57 +++++++++++++------ 5 files changed, 140 insertions(+), 47 deletions(-) diff --git a/app/scripts/controllers/incoming-transactions.js b/app/scripts/controllers/incoming-transactions.js index bb47cdac3..361399483 100644 --- a/app/scripts/controllers/incoming-transactions.js +++ b/app/scripts/controllers/incoming-transactions.js @@ -2,7 +2,7 @@ import { ObservableStore } from '@metamask/obs-store'; import log from 'loglevel'; import BN from 'bn.js'; import createId from '../../../shared/modules/random-id'; -import { bnToHex } from '../lib/util'; +import { bnToHex, previousValueComparator } from '../lib/util'; import getFetchWithTimeout from '../../../shared/modules/fetch-with-timeout'; import { @@ -61,10 +61,12 @@ export default class IncomingTransactionsController { onNetworkDidChange, getCurrentChainId, preferencesController, + onboardingController, } = opts; this.blockTracker = blockTracker; this.getCurrentChainId = getCurrentChainId; this.preferencesController = preferencesController; + this.onboardingController = onboardingController; this._onLatestBlock = async (newBlockNumberHex) => { const selectedAddress = this.preferencesController.getSelectedAddress(); @@ -121,6 +123,17 @@ export default class IncomingTransactionsController { }, this.preferencesController.store.getState()), ); + this.onboardingController.store.subscribe( + previousValueComparator(async (prevState, currState) => { + const { completedOnboarding: prevCompletedOnboarding } = prevState; + const { completedOnboarding: currCompletedOnboarding } = currState; + if (!prevCompletedOnboarding && currCompletedOnboarding) { + const address = this.preferencesController.getSelectedAddress(); + await this._update(address); + } + }, this.onboardingController.store.getState()), + ); + onNetworkDidChange(async () => { const address = this.preferencesController.getSelectedAddress(); await this._update(address); @@ -154,8 +167,13 @@ export default class IncomingTransactionsController { * @param {number} [newBlockNumberDec] - block number to begin fetching from */ async _update(address, newBlockNumberDec) { + const { completedOnboarding } = this.onboardingController.store.getState(); const chainId = this.getCurrentChainId(); - if (!etherscanSupportedNetworks.includes(chainId) || !address) { + if ( + !etherscanSupportedNetworks.includes(chainId) || + !address || + !completedOnboarding + ) { return; } try { @@ -293,31 +311,3 @@ export default class IncomingTransactionsController { }; } } - -/** - * Returns a function with arity 1 that caches the argument that the function - * is called with and invokes the comparator with both the cached, previous, - * value and the current value. If specified, the initialValue will be passed - * in as the previous value on the first invocation of the returned method. - * - * @template A - The type of the compared value. - * @param {(prevValue: A, nextValue: A) => void} comparator - A method to compare - * the previous and next values. - * @param {A} [initialValue] - The initial value to supply to prevValue - * on first call of the method. - */ -function previousValueComparator(comparator, initialValue) { - let first = true; - let cache; - return (value) => { - try { - if (first) { - first = false; - return comparator(initialValue ?? value, value); - } - return comparator(cache, value); - } finally { - cache = value; - } - }; -} diff --git a/app/scripts/controllers/incoming-transactions.test.js b/app/scripts/controllers/incoming-transactions.test.js index fd11163ff..2675c62ce 100644 --- a/app/scripts/controllers/incoming-transactions.test.js +++ b/app/scripts/controllers/incoming-transactions.test.js @@ -77,6 +77,17 @@ function getMockPreferencesController({ }; } +function getMockOnboardingController() { + return { + store: { + getState: sinon.stub().returns({ + completedOnboarding: true, + }), + subscribe: sinon.spy(), + }, + }; +} + function getMockBlockTracker() { return { addListener: sinon.stub().callsArgWithAsync(1, '0xa'), @@ -169,6 +180,7 @@ describe('IncomingTransactionsController', function () { blockTracker: getMockBlockTracker(), ...mockedNetworkMethods, preferencesController: getMockPreferencesController(), + onboardingController: getMockOnboardingController(), initState: {}, }, ); @@ -199,6 +211,7 @@ describe('IncomingTransactionsController', function () { blockTracker: getMockBlockTracker(), ...getMockNetworkControllerMethods(), preferencesController: getMockPreferencesController(), + onboardingController: getMockOnboardingController(), initState: getNonEmptyInitState(), }, ); @@ -217,6 +230,7 @@ describe('IncomingTransactionsController', function () { blockTracker: getMockBlockTracker(), ...getMockNetworkControllerMethods(), preferencesController: getMockPreferencesController(), + onboardingController: getMockOnboardingController(), initState: {}, }, ); @@ -239,6 +253,7 @@ describe('IncomingTransactionsController', function () { blockTracker: getMockBlockTracker(), ...getMockNetworkControllerMethods(CHAIN_IDS.GOERLI), preferencesController: getMockPreferencesController(), + onboardingController: getMockOnboardingController(), initState: getNonEmptyInitState(), }, ); @@ -344,6 +359,7 @@ describe('IncomingTransactionsController', function () { blockTracker: getMockBlockTracker(), ...getMockNetworkControllerMethods(), preferencesController: getMockPreferencesController(), + onboardingController: getMockOnboardingController(), initState: getNonEmptyInitState(), }, ); @@ -394,6 +410,7 @@ describe('IncomingTransactionsController', function () { preferencesController: getMockPreferencesController({ showIncomingTransactions: false, }), + onboardingController: getMockOnboardingController(), initState: getNonEmptyInitState(), }, ); @@ -441,6 +458,7 @@ describe('IncomingTransactionsController', function () { blockTracker: getMockBlockTracker(), ...getMockNetworkControllerMethods(CHAIN_IDS.GOERLI), preferencesController: getMockPreferencesController(), + onboardingController: getMockOnboardingController(), initState: getNonEmptyInitState(), }, ); @@ -486,6 +504,7 @@ describe('IncomingTransactionsController', function () { blockTracker: getMockBlockTracker(), ...getMockNetworkControllerMethods(CHAIN_IDS.GOERLI), preferencesController: getMockPreferencesController(), + onboardingController: getMockOnboardingController(), initState: getNonEmptyInitState(), }, ); @@ -533,6 +552,7 @@ describe('IncomingTransactionsController', function () { blockTracker: getMockBlockTracker(), ...getMockNetworkControllerMethods(CHAIN_IDS.GOERLI), preferencesController: getMockPreferencesController(), + onboardingController: getMockOnboardingController(), initState: getNonEmptyInitState(), }, ); @@ -624,6 +644,7 @@ describe('IncomingTransactionsController', function () { blockTracker: { ...getMockBlockTracker() }, ...getMockNetworkControllerMethods(), preferencesController: getMockPreferencesController(), + onboardingController: getMockOnboardingController(), initState: getNonEmptyInitState(), }, ); @@ -685,6 +706,7 @@ describe('IncomingTransactionsController', function () { blockTracker: getMockBlockTracker(), ...mockedNetworkMethods, preferencesController: getMockPreferencesController(), + onboardingController: getMockOnboardingController(), initState: getNonEmptyInitState(), }, ); @@ -768,6 +790,7 @@ describe('IncomingTransactionsController', function () { blockTracker: getMockBlockTracker(), ...mockedNetworkMethods, preferencesController: getMockPreferencesController(), + onboardingController: getMockOnboardingController(), initState: getNonEmptyInitState(), }, ); @@ -822,6 +845,7 @@ describe('IncomingTransactionsController', function () { blockTracker: getMockBlockTracker(), ...getMockNetworkControllerMethods(CHAIN_IDS.GOERLI), preferencesController: getMockPreferencesController(), + onboardingController: getMockOnboardingController(), initState: getEmptyInitState(), getCurrentChainId: () => CHAIN_IDS.GOERLI, }); @@ -858,6 +882,7 @@ describe('IncomingTransactionsController', function () { blockTracker: getMockBlockTracker(), ...getMockNetworkControllerMethods(CHAIN_IDS.GOERLI), preferencesController: getMockPreferencesController(), + onboardingController: getMockOnboardingController(), initState: getEmptyInitState(), getCurrentChainId: () => CHAIN_IDS.GOERLI, }); @@ -911,6 +936,7 @@ describe('IncomingTransactionsController', function () { blockTracker: getMockBlockTracker(), ...getMockNetworkControllerMethods(CHAIN_IDS.GOERLI), preferencesController: getMockPreferencesController(), + onboardingController: getMockOnboardingController(), initState: getNonEmptyInitState(), getCurrentChainId: () => CHAIN_IDS.GOERLI, }); @@ -951,6 +977,7 @@ describe('IncomingTransactionsController', function () { blockTracker: getMockBlockTracker(), ...getMockNetworkControllerMethods(CHAIN_IDS.GOERLI), preferencesController: getMockPreferencesController(), + onboardingController: getMockOnboardingController(), initState: getNonEmptyInitState(), getCurrentChainId: () => CHAIN_IDS.GOERLI, }, @@ -1026,6 +1053,7 @@ describe('IncomingTransactionsController', function () { blockTracker: getMockBlockTracker(), ...getMockNetworkControllerMethods(CHAIN_IDS.GOERLI), preferencesController: getMockPreferencesController(), + onboardingController: getMockOnboardingController(), initState: getNonEmptyInitState(), }, ); @@ -1049,6 +1077,7 @@ describe('IncomingTransactionsController', function () { blockTracker: getMockBlockTracker(), ...getMockNetworkControllerMethods(CHAIN_IDS.MAINNET), preferencesController: getMockPreferencesController(), + onboardingController: getMockOnboardingController(), initState: getNonEmptyInitState(), }, ); @@ -1072,6 +1101,7 @@ describe('IncomingTransactionsController', function () { blockTracker: getMockBlockTracker(), ...getMockNetworkControllerMethods(CHAIN_IDS.GOERLI), preferencesController: getMockPreferencesController(), + onboardingController: getMockOnboardingController(), initState: getNonEmptyInitState(), }, ); @@ -1095,6 +1125,7 @@ describe('IncomingTransactionsController', function () { blockTracker: getMockBlockTracker(), ...getMockNetworkControllerMethods(CHAIN_IDS.GOERLI), preferencesController: getMockPreferencesController(), + onboardingController: getMockOnboardingController(), initState: getNonEmptyInitState(), }, ); @@ -1128,6 +1159,7 @@ describe('IncomingTransactionsController', function () { blockTracker: getMockBlockTracker(), ...getMockNetworkControllerMethods(CHAIN_IDS.GOERLI), preferencesController: getMockPreferencesController(), + onboardingController: getMockOnboardingController(), initState: getNonEmptyInitState(), }, ); @@ -1156,6 +1188,7 @@ describe('IncomingTransactionsController', function () { blockTracker: getMockBlockTracker(), ...getMockNetworkControllerMethods(CHAIN_IDS.GOERLI), preferencesController: getMockPreferencesController(), + onboardingController: getMockOnboardingController(), initState: getNonEmptyInitState(), }, ); @@ -1179,6 +1212,7 @@ describe('IncomingTransactionsController', function () { blockTracker: getMockBlockTracker(), ...getMockNetworkControllerMethods(CHAIN_IDS.GOERLI), preferencesController: getMockPreferencesController(), + onboardingController: getMockOnboardingController(), initState: getNonEmptyInitState(), }, ); @@ -1225,6 +1259,7 @@ describe('IncomingTransactionsController', function () { blockTracker: getMockBlockTracker(), ...getMockNetworkControllerMethods(CHAIN_IDS.GOERLI), preferencesController: getMockPreferencesController(), + onboardingController: getMockOnboardingController(), initState: getNonEmptyInitState(), }, ); @@ -1271,6 +1306,7 @@ describe('IncomingTransactionsController', function () { blockTracker: getMockBlockTracker(), ...getMockNetworkControllerMethods(CHAIN_IDS.GOERLI), preferencesController: getMockPreferencesController(), + onboardingController: getMockOnboardingController(), initState: getNonEmptyInitState(), }, ); diff --git a/app/scripts/lib/account-tracker.js b/app/scripts/lib/account-tracker.js index 5ec2d8a23..ddbaf7ec4 100644 --- a/app/scripts/lib/account-tracker.js +++ b/app/scripts/lib/account-tracker.js @@ -30,6 +30,7 @@ import { SINGLE_CALL_BALANCES_ADDRESS_FANTOM, SINGLE_CALL_BALANCES_ADDRESS_ARBITRUM, } from '../constants/contracts'; +import { previousValueComparator } from './util'; /** * This module is responsible for tracking any number of accounts and caching their current balances & transaction @@ -79,8 +80,19 @@ export default class AccountTracker { this.getCurrentChainId = opts.getCurrentChainId; this.getNetworkIdentifier = opts.getNetworkIdentifier; this.preferencesController = opts.preferencesController; + this.onboardingController = opts.onboardingController; this.ethersProvider = new ethers.providers.Web3Provider(this._provider); + + this.onboardingController.store.subscribe( + previousValueComparator(async (prevState, currState) => { + const { completedOnboarding: prevCompletedOnboarding } = prevState; + const { completedOnboarding: currCompletedOnboarding } = currState; + if (!prevCompletedOnboarding && currCompletedOnboarding) { + this._updateAccounts(); + } + }, this.onboardingController.store.getState()), + ); } start() { @@ -206,6 +218,10 @@ export default class AccountTracker { * @returns {Promise} after all account balances updated */ async _updateAccounts() { + const { completedOnboarding } = this.onboardingController.store.getState(); + if (!completedOnboarding) { + return; + } const { useMultiAccountBalanceChecker } = this.preferencesController.store.getState(); diff --git a/app/scripts/lib/util.js b/app/scripts/lib/util.js index 929b5d602..1f19901f1 100644 --- a/app/scripts/lib/util.js +++ b/app/scripts/lib/util.js @@ -218,3 +218,31 @@ export function deferredPromise() { }); return { promise, resolve, reject }; } + +/** + * Returns a function with arity 1 that caches the argument that the function + * is called with and invokes the comparator with both the cached, previous, + * value and the current value. If specified, the initialValue will be passed + * in as the previous value on the first invocation of the returned method. + * + * @template A - The type of the compared value. + * @param {(prevValue: A, nextValue: A) => void} comparator - A method to compare + * the previous and next values. + * @param {A} [initialValue] - The initial value to supply to prevValue + * on first call of the method. + */ +export function previousValueComparator(comparator, initialValue) { + let first = true; + let cache; + return (value) => { + try { + if (first) { + first = false; + return comparator(initialValue ?? value, value); + } + return comparator(cache, value); + } finally { + cache = value; + } + }; +} diff --git a/app/scripts/metamask-controller.js b/app/scripts/metamask-controller.js index d07f58b62..089423135 100644 --- a/app/scripts/metamask-controller.js +++ b/app/scripts/metamask-controller.js @@ -147,6 +147,7 @@ import seedPhraseVerifier from './lib/seed-phrase-verifier'; import MetaMetricsController from './controllers/metametrics'; import { segment } from './lib/segment'; import createMetaRPCHandler from './lib/createMetaRPCHandler'; +import { previousValueComparator } from './lib/util'; import { CaveatMutatorFactories, @@ -528,6 +529,10 @@ export default class MetamaskController extends EventEmitter { ), }); + this.onboardingController = new OnboardingController({ + initState: initState.OnboardingController, + }); + this.incomingTransactionsController = new IncomingTransactionsController({ blockTracker: this.blockTracker, onNetworkDidChange: this.networkController.on.bind( @@ -538,6 +543,7 @@ export default class MetamaskController extends EventEmitter { this.networkController, ), preferencesController: this.preferencesController, + onboardingController: this.onboardingController, initState: initState.IncomingTransactionsController, }); @@ -552,27 +558,30 @@ export default class MetamaskController extends EventEmitter { this.networkController, ), preferencesController: this.preferencesController, + onboardingController: this.onboardingController, }); // start and stop polling for balances based on activeControllerConnections this.on('controllerConnectionChanged', (activeControllerConnections) => { - if (activeControllerConnections > 0) { - this.accountTracker.start(); - this.incomingTransactionsController.start(); - this.currencyRateController.start(); - if (this.preferencesController.store.getState().useTokenDetection) { - this.tokenListController.start(); - } + const { completedOnboarding } = + this.onboardingController.store.getState(); + if (activeControllerConnections > 0 && completedOnboarding) { + this.triggerNetworkrequests(); } else { - this.accountTracker.stop(); - this.incomingTransactionsController.stop(); - this.currencyRateController.stop(); - if (this.preferencesController.store.getState().useTokenDetection) { - this.tokenListController.stop(); - } + this.stopNetworkRequests(); } }); + this.onboardingController.store.subscribe( + previousValueComparator(async (prevState, currState) => { + const { completedOnboarding: prevCompletedOnboarding } = prevState; + const { completedOnboarding: currCompletedOnboarding } = currState; + if (!prevCompletedOnboarding && currCompletedOnboarding) { + this.triggerNetworkrequests(); + } + }, this.onboardingController.store.getState()), + ); + this.cachedBalancesController = new CachedBalancesController({ accountTracker: this.accountTracker, getCurrentChainId: this.networkController.getCurrentChainId.bind( @@ -581,10 +590,6 @@ export default class MetamaskController extends EventEmitter { initState: initState.CachedBalancesController, }); - this.onboardingController = new OnboardingController({ - initState: initState.OnboardingController, - }); - this.tokensController.hub.on('pendingSuggestedAsset', async () => { await opts.openPopup(); }); @@ -1192,6 +1197,24 @@ export default class MetamaskController extends EventEmitter { checkForMultipleVersionsRunning(); } + triggerNetworkrequests() { + this.accountTracker.start(); + this.incomingTransactionsController.start(); + this.currencyRateController.start(); + if (this.preferencesController.store.getState().useTokenDetection) { + this.tokenListController.start(); + } + } + + stopNetworkRequests() { + this.accountTracker.stop(); + this.incomingTransactionsController.stop(); + this.currencyRateController.stop(); + if (this.preferencesController.store.getState().useTokenDetection) { + this.tokenListController.stop(); + } + } + resetStates(resetMethods) { resetMethods.forEach((resetMethod) => { try {