diff --git a/app/scripts/controllers/detect-tokens.js b/app/scripts/controllers/detect-tokens.js index c1c7efea07e4..86c8cc222856 100644 --- a/app/scripts/controllers/detect-tokens.js +++ b/app/scripts/controllers/detect-tokens.js @@ -28,7 +28,6 @@ export default class DetectTokensController { * @param config.interval * @param config.preferences * @param config.network - * @param config.keyringMemStore * @param config.tokenList * @param config.tokensController * @param config.assetsContractController @@ -89,13 +88,8 @@ export default class DetectTokensController { this.restartTokenDetection({ chainId: this.chainId }); } }); - messenger.subscribe('KeyringController:unlock', () => { - this.isUnlocked = true; - this.restartTokenDetection(); - }); - messenger.subscribe('KeyringController:lock', () => { - this.isUnlocked = false; - }); + + this.#registerKeyringHandlers(); } /** @@ -260,4 +254,22 @@ export default class DetectTokensController { return this.isOpen && this.isUnlocked; } /* eslint-enable accessor-pairs */ + + /** + * Constructor helper to register listeners on the keyring + * locked state changes + */ + #registerKeyringHandlers() { + const { isUnlocked } = this.messenger.call('KeyringController:getState'); + this.isUnlocked = isUnlocked; + + this.messenger.subscribe('KeyringController:unlock', () => { + this.isUnlocked = true; + this.restartTokenDetection(); + }); + + this.messenger.subscribe('KeyringController:lock', () => { + this.isUnlocked = false; + }); + } } diff --git a/app/scripts/controllers/detect-tokens.test.js b/app/scripts/controllers/detect-tokens.test.js index 25a959cf9102..0db3be07e5c4 100644 --- a/app/scripts/controllers/detect-tokens.test.js +++ b/app/scripts/controllers/detect-tokens.test.js @@ -15,17 +15,6 @@ import { toChecksumHexAddress } from '../../../shared/modules/hexstring-utils'; import DetectTokensController from './detect-tokens'; import PreferencesController from './preferences'; -function buildMessenger() { - return new ControllerMessenger().getRestricted({ - name: 'DetectTokensController', - allowedEvents: [ - 'NetworkController:stateChange', - 'KeyringController:lock', - 'KeyringController:unlock', - ], - }); -} - describe('DetectTokensController', function () { let sandbox, assetsContractController, @@ -33,10 +22,23 @@ describe('DetectTokensController', function () { preferences, provider, tokensController, - tokenListController; + tokenListController, + messenger; const noop = () => undefined; + const getRestrictedMessenger = () => { + return messenger.getRestricted({ + name: 'DetectTokensController', + allowedActions: ['KeyringController:getState'], + allowedEvents: [ + 'NetworkController:stateChange', + 'KeyringController:lock', + 'KeyringController:unlock', + ], + }); + }; + const networkControllerProviderConfig = { getAccounts: noop, }; @@ -200,6 +202,11 @@ describe('DetectTokensController', function () { .reply(200, { error: 'ChainId 3 is not supported' }) .persist(); + messenger = new ControllerMessenger(); + messenger.registerActionHandler('KeyringController:getState', () => ({ + isUnlocked: true, + })); + const networkControllerMessenger = new ControllerMessenger(); network = new NetworkController({ messenger: networkControllerMessenger, @@ -264,7 +271,11 @@ describe('DetectTokensController', function () { it('should poll on correct interval', async function () { const stub = sinon.stub(global, 'setInterval'); - new DetectTokensController({ messenger: buildMessenger(), interval: 1337 }); // eslint-disable-line no-new + // eslint-disable-next-line no-new + new DetectTokensController({ + messenger: getRestrictedMessenger(), + interval: 1337, + }); assert.strictEqual(stub.getCall(0).args[1], 1337); stub.restore(); }); @@ -273,7 +284,7 @@ describe('DetectTokensController', function () { const clock = sandbox.useFakeTimers(); await network.setProviderType(NETWORK_TYPES.MAINNET); const controller = new DetectTokensController({ - messenger: buildMessenger(), + messenger: getRestrictedMessenger(), preferences, network, tokenList: tokenListController, @@ -309,7 +320,7 @@ describe('DetectTokensController', function () { }); await tokenListController.start(); const controller = new DetectTokensController({ - messenger: buildMessenger(), + messenger: getRestrictedMessenger(), preferences, network, tokenList: tokenListController, @@ -332,7 +343,7 @@ describe('DetectTokensController', function () { sandbox.useFakeTimers(); await network.setProviderType(NETWORK_TYPES.MAINNET); const controller = new DetectTokensController({ - messenger: buildMessenger(), + messenger: getRestrictedMessenger(), preferences, network, tokenList: tokenListController, @@ -383,7 +394,7 @@ describe('DetectTokensController', function () { sandbox.useFakeTimers(); await network.setProviderType(NETWORK_TYPES.MAINNET); const controller = new DetectTokensController({ - messenger: buildMessenger(), + messenger: getRestrictedMessenger(), preferences, network, tokenList: tokenListController, @@ -441,7 +452,7 @@ describe('DetectTokensController', function () { it('should trigger detect new tokens when change address', async function () { sandbox.useFakeTimers(); const controller = new DetectTokensController({ - messenger: buildMessenger(), + messenger: getRestrictedMessenger(), preferences, network, tokenList: tokenListController, @@ -459,9 +470,8 @@ describe('DetectTokensController', function () { it('should trigger detect new tokens when submit password', async function () { sandbox.useFakeTimers(); - const messenger = buildMessenger(); const controller = new DetectTokensController({ - messenger: buildMessenger(), + messenger: getRestrictedMessenger(), preferences, network, tokenList: tokenListController, @@ -481,7 +491,7 @@ describe('DetectTokensController', function () { const clock = sandbox.useFakeTimers(); await network.setProviderType(NETWORK_TYPES.MAINNET); const controller = new DetectTokensController({ - messenger: buildMessenger(), + messenger: getRestrictedMessenger(), preferences, network, tokenList: tokenListController, @@ -502,7 +512,7 @@ describe('DetectTokensController', function () { const clock = sandbox.useFakeTimers(); await network.setProviderType(NETWORK_TYPES.MAINNET); const controller = new DetectTokensController({ - messenger: buildMessenger(), + messenger: getRestrictedMessenger(), preferences, network, tokensController, diff --git a/app/scripts/metamask-controller.js b/app/scripts/metamask-controller.js index fa8080932bae..fd5c6427755f 100644 --- a/app/scripts/metamask-controller.js +++ b/app/scripts/metamask-controller.js @@ -1142,6 +1142,7 @@ export default class MetamaskController extends EventEmitter { const detectTokensControllerMessenger = this.controllerMessenger.getRestricted({ name: 'DetectTokensController', + allowedActions: ['KeyringController:getState'], allowedEvents: [ 'NetworkController:stateChange', 'KeyringController:lock',