diff --git a/app/scripts/controllers/detect-tokens.js b/app/scripts/controllers/detect-tokens.js index 46719aebcced..86c8cc222856 100644 --- a/app/scripts/controllers/detect-tokens.js +++ b/app/scripts/controllers/detect-tokens.js @@ -28,7 +28,6 @@ export default class DetectTokensController { * @param config.interval * @param config.preferences * @param config.network - * @param config.keyringMemStore * @param config.tokenList * @param config.tokensController * @param config.assetsContractController @@ -40,7 +39,6 @@ export default class DetectTokensController { interval = DEFAULT_INTERVAL, preferences, network, - keyringMemStore, tokenList, tokensController, assetsContractController = null, @@ -52,7 +50,6 @@ export default class DetectTokensController { this.preferences = preferences; this.interval = interval; this.network = network; - this.keyringMemStore = keyringMemStore; this.tokenList = tokenList; this.useTokenDetection = this.preferences?.store.getState().useTokenDetection; @@ -91,6 +88,8 @@ export default class DetectTokensController { this.restartTokenDetection({ chainId: this.chainId }); } }); + + this.#registerKeyringHandlers(); } /** @@ -236,26 +235,6 @@ export default class DetectTokensController { }, interval); } - /** - * In setter when isUnlocked is updated to true, detectNewTokens and restart polling - * - * @type {object} - */ - set keyringMemStore(keyringMemStore) { - if (!keyringMemStore) { - return; - } - this._keyringMemStore = keyringMemStore; - this._keyringMemStore.subscribe(({ isUnlocked }) => { - if (this.isUnlocked !== isUnlocked) { - this.isUnlocked = isUnlocked; - if (isUnlocked) { - this.restartTokenDetection(); - } - } - }); - } - /** * @type {object} */ @@ -275,4 +254,22 @@ export default class DetectTokensController { return this.isOpen && this.isUnlocked; } /* eslint-enable accessor-pairs */ + + /** + * Constructor helper to register listeners on the keyring + * locked state changes + */ + #registerKeyringHandlers() { + const { isUnlocked } = this.messenger.call('KeyringController:getState'); + this.isUnlocked = isUnlocked; + + this.messenger.subscribe('KeyringController:unlock', () => { + this.isUnlocked = true; + this.restartTokenDetection(); + }); + + this.messenger.subscribe('KeyringController:lock', () => { + this.isUnlocked = false; + }); + } } diff --git a/app/scripts/controllers/detect-tokens.test.js b/app/scripts/controllers/detect-tokens.test.js index 8d1c182c249e..c7306dc83295 100644 --- a/app/scripts/controllers/detect-tokens.test.js +++ b/app/scripts/controllers/detect-tokens.test.js @@ -1,7 +1,6 @@ import { strict as assert } from 'assert'; import sinon from 'sinon'; import nock from 'nock'; -import { ObservableStore } from '@metamask/obs-store'; import BigNumber from 'bignumber.js'; import { ControllerMessenger } from '@metamask/base-controller'; import { @@ -16,25 +15,30 @@ import { toChecksumHexAddress } from '../../../shared/modules/hexstring-utils'; import DetectTokensController from './detect-tokens'; import PreferencesController from './preferences'; -function buildMessenger() { - return new ControllerMessenger().getRestricted({ - name: 'DetectTokensController', - allowedEvents: ['NetworkController:stateChange'], - }); -} - describe('DetectTokensController', function () { let sandbox, assetsContractController, - keyringMemStore, network, preferences, provider, tokensController, - tokenListController; + tokenListController, + messenger; const noop = () => undefined; + const getRestrictedMessenger = () => { + return messenger.getRestricted({ + name: 'DetectTokensController', + allowedActions: ['KeyringController:getState'], + allowedEvents: [ + 'NetworkController:stateChange', + 'KeyringController:lock', + 'KeyringController:unlock', + ], + }); + }; + const networkControllerProviderConfig = { getAccounts: noop, }; @@ -198,7 +202,11 @@ describe('DetectTokensController', function () { .reply(200, { error: 'ChainId 3 is not supported' }) .persist(); - keyringMemStore = new ObservableStore({ isUnlocked: false }); + messenger = new ControllerMessenger(); + messenger.registerActionHandler('KeyringController:getState', () => ({ + isUnlocked: true, + })); + const networkControllerMessenger = new ControllerMessenger(); network = new NetworkController({ messenger: networkControllerMessenger, @@ -263,7 +271,11 @@ describe('DetectTokensController', function () { it('should poll on correct interval', async function () { const stub = sinon.stub(global, 'setInterval'); - new DetectTokensController({ messenger: buildMessenger(), interval: 1337 }); // eslint-disable-line no-new + // eslint-disable-next-line no-new + new DetectTokensController({ + messenger: getRestrictedMessenger(), + interval: 1337, + }); assert.strictEqual(stub.getCall(0).args[1], 1337); stub.restore(); }); @@ -272,10 +284,9 @@ describe('DetectTokensController', function () { const clock = sandbox.useFakeTimers(); await network.setProviderType(NETWORK_TYPES.MAINNET); const controller = new DetectTokensController({ - messenger: buildMessenger(), + messenger: getRestrictedMessenger(), preferences, network, - keyringMemStore, tokenList: tokenListController, tokensController, assetsContractController, @@ -309,10 +320,9 @@ describe('DetectTokensController', function () { }); await tokenListController.start(); const controller = new DetectTokensController({ - messenger: buildMessenger(), + messenger: getRestrictedMessenger(), preferences, network, - keyringMemStore, tokenList: tokenListController, tokensController, assetsContractController, @@ -333,10 +343,9 @@ describe('DetectTokensController', function () { sandbox.useFakeTimers(); await network.setProviderType(NETWORK_TYPES.MAINNET); const controller = new DetectTokensController({ - messenger: buildMessenger(), + messenger: getRestrictedMessenger(), preferences, network, - keyringMemStore, tokenList: tokenListController, tokensController, assetsContractController, @@ -385,10 +394,9 @@ describe('DetectTokensController', function () { sandbox.useFakeTimers(); await network.setProviderType(NETWORK_TYPES.MAINNET); const controller = new DetectTokensController({ - messenger: buildMessenger(), + messenger: getRestrictedMessenger(), preferences, network, - keyringMemStore, tokenList: tokenListController, tokensController, assetsContractController, @@ -444,10 +452,9 @@ describe('DetectTokensController', function () { it('should trigger detect new tokens when change address', async function () { sandbox.useFakeTimers(); const controller = new DetectTokensController({ - messenger: buildMessenger(), + messenger: getRestrictedMessenger(), preferences, network, - keyringMemStore, tokenList: tokenListController, tokensController, assetsContractController, @@ -464,10 +471,9 @@ describe('DetectTokensController', function () { it('should trigger detect new tokens when submit password', async function () { sandbox.useFakeTimers(); const controller = new DetectTokensController({ - messenger: buildMessenger(), + messenger: getRestrictedMessenger(), preferences, network, - keyringMemStore, tokenList: tokenListController, tokensController, assetsContractController, @@ -475,18 +481,38 @@ describe('DetectTokensController', function () { controller.isOpen = true; controller.selectedAddress = '0x0'; const stub = sandbox.stub(controller, 'detectNewTokens'); - await controller._keyringMemStore.updateState({ isUnlocked: true }); + + messenger.publish('KeyringController:unlock'); + sandbox.assert.called(stub); + assert.equal(controller.isUnlocked, true); + }); + + it('should not be active after lock event is emitted', async function () { + sandbox.useFakeTimers(); + const controller = new DetectTokensController({ + messenger: getRestrictedMessenger(), + preferences, + network, + tokenList: tokenListController, + tokensController, + assetsContractController, + }); + controller.isOpen = true; + + messenger.publish('KeyringController:lock'); + + assert.equal(controller.isUnlocked, false); + assert.equal(controller.isActive, false); }); it('should not trigger detect new tokens when not unlocked', async function () { const clock = sandbox.useFakeTimers(); await network.setProviderType(NETWORK_TYPES.MAINNET); const controller = new DetectTokensController({ - messenger: buildMessenger(), + messenger: getRestrictedMessenger(), preferences, network, - keyringMemStore, tokenList: tokenListController, tokensController, assetsContractController, @@ -505,10 +531,9 @@ describe('DetectTokensController', function () { const clock = sandbox.useFakeTimers(); await network.setProviderType(NETWORK_TYPES.MAINNET); const controller = new DetectTokensController({ - messenger: buildMessenger(), + messenger: getRestrictedMessenger(), preferences, network, - keyringMemStore, tokensController, assetsContractController, }); diff --git a/app/scripts/metamask-controller.js b/app/scripts/metamask-controller.js index ee08fe92682b..fd5c6427755f 100644 --- a/app/scripts/metamask-controller.js +++ b/app/scripts/metamask-controller.js @@ -1142,7 +1142,12 @@ export default class MetamaskController extends EventEmitter { const detectTokensControllerMessenger = this.controllerMessenger.getRestricted({ name: 'DetectTokensController', - allowedEvents: ['NetworkController:stateChange'], + allowedActions: ['KeyringController:getState'], + allowedEvents: [ + 'NetworkController:stateChange', + 'KeyringController:lock', + 'KeyringController:unlock', + ], }); this.detectTokensController = new DetectTokensController({ messenger: detectTokensControllerMessenger, @@ -1150,7 +1155,6 @@ export default class MetamaskController extends EventEmitter { tokensController: this.tokensController, assetsContractController: this.assetsContractController, network: this.networkController, - keyringMemStore: this.keyringController.memStore, tokenList: this.tokenListController, trackMetaMetricsEvent: this.metaMetricsController.trackEvent.bind( this.metaMetricsController,