1
0
mirror of https://github.com/kremalicious/metamask-extension.git synced 2024-12-23 09:52:26 +01:00

Use getKeyringForAccount from core KeyringController (#20202)

This commit is contained in:
cryptodev-2s 2023-07-28 20:09:14 +01:00 committed by GitHub
parent 537f1c7aee
commit b576c5245c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 54 additions and 49 deletions

View File

@ -19,7 +19,7 @@ const messageIdMock2 = '456';
const stateMock = { test: 123 }; const stateMock = { test: 123 };
const addressMock = '0xc38bf1ad06ef69f0c04e29dbeb4152b4175f0a8d'; const addressMock = '0xc38bf1ad06ef69f0c04e29dbeb4152b4175f0a8d';
const publicKeyMock = '32762347862378feb87123781623a='; const publicKeyMock = '32762347862378feb87123781623a=';
const keyringMock = { type: KeyringType.hdKeyTree }; const keyringTypeMock = KeyringType.hdKeyTree;
const messageParamsMock = { const messageParamsMock = {
from: addressMock, from: addressMock,
@ -73,11 +73,6 @@ const createEncryptionPublicKeyManagerMock = <T>() =>
}, },
} as any as jest.Mocked<T>); } as any as jest.Mocked<T>);
const createKeyringControllerMock = () => ({
getKeyringForAccount: jest.fn(),
getEncryptionPublicKey: jest.fn(),
});
describe('EncryptionPublicKeyController', () => { describe('EncryptionPublicKeyController', () => {
let encryptionPublicKeyController: EncryptionPublicKeyController; let encryptionPublicKeyController: EncryptionPublicKeyController;
@ -88,7 +83,8 @@ describe('EncryptionPublicKeyController', () => {
const encryptionPublicKeyManagerMock = const encryptionPublicKeyManagerMock =
createEncryptionPublicKeyManagerMock<EncryptionPublicKeyManager>(); createEncryptionPublicKeyManagerMock<EncryptionPublicKeyManager>();
const messengerMock = createMessengerMock(); const messengerMock = createMessengerMock();
const keyringControllerMock = createKeyringControllerMock(); const getEncryptionPublicKeyMock = jest.fn();
const getAccountKeyringTypeMock = jest.fn();
const getStateMock = jest.fn(); const getStateMock = jest.fn();
const metricsEventMock = jest.fn(); const metricsEventMock = jest.fn();
@ -101,7 +97,8 @@ describe('EncryptionPublicKeyController', () => {
encryptionPublicKeyController = new EncryptionPublicKeyController({ encryptionPublicKeyController = new EncryptionPublicKeyController({
messenger: messengerMock as any, messenger: messengerMock as any,
keyringController: keyringControllerMock as any, getEncryptionPublicKey: getEncryptionPublicKeyMock as any,
getAccountKeyringType: getAccountKeyringTypeMock as any,
getState: getStateMock as any, getState: getStateMock as any,
metricsEvent: metricsEventMock as any, metricsEvent: metricsEventMock as any,
} as EncryptionPublicKeyControllerOptions); } as EncryptionPublicKeyControllerOptions);
@ -203,9 +200,7 @@ describe('EncryptionPublicKeyController', () => {
])( ])(
'throws if keyring is not supported', 'throws if keyring is not supported',
async (keyringName, keyringType) => { async (keyringName, keyringType) => {
keyringControllerMock.getKeyringForAccount.mockResolvedValueOnce({ getAccountKeyringTypeMock.mockResolvedValueOnce(keyringType);
type: keyringType,
});
await expect( await expect(
encryptionPublicKeyController.newRequestEncryptionPublicKey( encryptionPublicKeyController.newRequestEncryptionPublicKey(
@ -219,9 +214,7 @@ describe('EncryptionPublicKeyController', () => {
); );
it('adds message to message manager', async () => { it('adds message to message manager', async () => {
keyringControllerMock.getKeyringForAccount.mockResolvedValueOnce( getAccountKeyringTypeMock.mockResolvedValueOnce(keyringTypeMock);
keyringMock,
);
await encryptionPublicKeyController.newRequestEncryptionPublicKey( await encryptionPublicKeyController.newRequestEncryptionPublicKey(
addressMock, addressMock,
@ -243,9 +236,7 @@ describe('EncryptionPublicKeyController', () => {
from: messageParamsMock.data, from: messageParamsMock.data,
}); });
keyringControllerMock.getEncryptionPublicKey.mockResolvedValueOnce( getEncryptionPublicKeyMock.mockResolvedValueOnce(publicKeyMock);
publicKeyMock,
);
}); });
it('approves message and signs', async () => { it('approves message and signs', async () => {
@ -253,10 +244,8 @@ describe('EncryptionPublicKeyController', () => {
messageParamsMock, messageParamsMock,
); );
expect( expect(getEncryptionPublicKeyMock).toHaveBeenCalledTimes(1);
keyringControllerMock.getEncryptionPublicKey, expect(getEncryptionPublicKeyMock).toHaveBeenCalledWith(
).toHaveBeenCalledTimes(1);
expect(keyringControllerMock.getEncryptionPublicKey).toHaveBeenCalledWith(
messageParamsMock.data, messageParamsMock.data,
); );
@ -294,10 +283,8 @@ describe('EncryptionPublicKeyController', () => {
}); });
it('rejects message on error', async () => { it('rejects message on error', async () => {
keyringControllerMock.getEncryptionPublicKey.mockReset(); getEncryptionPublicKeyMock.mockReset();
keyringControllerMock.getEncryptionPublicKey.mockRejectedValue( getEncryptionPublicKeyMock.mockRejectedValue(new Error('Test Error'));
new Error('Test Error'),
);
await expect( await expect(
encryptionPublicKeyController.encryptionPublicKey(messageParamsMock), encryptionPublicKeyController.encryptionPublicKey(messageParamsMock),
@ -312,10 +299,8 @@ describe('EncryptionPublicKeyController', () => {
}); });
it('rejects approval on error', async () => { it('rejects approval on error', async () => {
keyringControllerMock.getEncryptionPublicKey.mockReset(); getEncryptionPublicKeyMock.mockReset();
keyringControllerMock.getEncryptionPublicKey.mockRejectedValue( getEncryptionPublicKeyMock.mockRejectedValue(new Error('Test Error'));
new Error('Test Error'),
);
await expect( await expect(
encryptionPublicKeyController.encryptionPublicKey(messageParamsMock), encryptionPublicKeyController.encryptionPublicKey(messageParamsMock),

View File

@ -4,7 +4,6 @@ import {
EncryptionPublicKeyManager, EncryptionPublicKeyManager,
EncryptionPublicKeyParamsMetamask, EncryptionPublicKeyParamsMetamask,
} from '@metamask/message-manager'; } from '@metamask/message-manager';
import { KeyringController } from '@metamask/eth-keyring-controller';
import { import {
AbstractMessageManager, AbstractMessageManager,
AbstractMessage, AbstractMessage,
@ -83,7 +82,8 @@ export type EncryptionPublicKeyControllerMessenger =
export type EncryptionPublicKeyControllerOptions = { export type EncryptionPublicKeyControllerOptions = {
messenger: EncryptionPublicKeyControllerMessenger; messenger: EncryptionPublicKeyControllerMessenger;
keyringController: KeyringController; getEncryptionPublicKey: (address: string) => Promise<string>;
getAccountKeyringType: (account: string) => Promise<string>;
getState: () => any; getState: () => any;
metricsEvent: (payload: any, options?: any) => void; metricsEvent: (payload: any, options?: any) => void;
}; };
@ -98,7 +98,9 @@ export default class EncryptionPublicKeyController extends BaseControllerV2<
> { > {
hub: EventEmitter; hub: EventEmitter;
private _keyringController: KeyringController; private _getEncryptionPublicKey: (address: string) => Promise<string>;
private _getAccountKeyringType: (account: string) => Promise<string>;
private _getState: () => any; private _getState: () => any;
@ -111,13 +113,15 @@ export default class EncryptionPublicKeyController extends BaseControllerV2<
* *
* @param options - The controller options. * @param options - The controller options.
* @param options.messenger - The restricted controller messenger for the EncryptionPublicKey controller. * @param options.messenger - The restricted controller messenger for the EncryptionPublicKey controller.
* @param options.keyringController - An instance of a keyring controller used to extract the encryption public key. * @param options.getEncryptionPublicKey - Callback to get the keyring encryption public key.
* @param options.getAccountKeyringType - Callback to get the keyring type.
* @param options.getState - Callback to retrieve all user state. * @param options.getState - Callback to retrieve all user state.
* @param options.metricsEvent - A function for emitting a metric event. * @param options.metricsEvent - A function for emitting a metric event.
*/ */
constructor({ constructor({
messenger, messenger,
keyringController, getEncryptionPublicKey,
getAccountKeyringType,
getState, getState,
metricsEvent, metricsEvent,
}: EncryptionPublicKeyControllerOptions) { }: EncryptionPublicKeyControllerOptions) {
@ -128,7 +132,8 @@ export default class EncryptionPublicKeyController extends BaseControllerV2<
state: getDefaultState(), state: getDefaultState(),
}); });
this._keyringController = keyringController; this._getEncryptionPublicKey = getEncryptionPublicKey;
this._getAccountKeyringType = getAccountKeyringType;
this._getState = getState; this._getState = getState;
this._metricsEvent = metricsEvent; this._metricsEvent = metricsEvent;
@ -186,9 +191,9 @@ export default class EncryptionPublicKeyController extends BaseControllerV2<
address: string, address: string,
req: OriginalRequest, req: OriginalRequest,
): Promise<string> { ): Promise<string> {
const keyring = await this._keyringController.getKeyringForAccount(address); const keyringType = await this._getAccountKeyringType(address);
switch (keyring.type) { switch (keyringType) {
case KeyringType.ledger: { case KeyringType.ledger: {
return new Promise((_, reject) => { return new Promise((_, reject) => {
reject( reject(
@ -244,7 +249,7 @@ export default class EncryptionPublicKeyController extends BaseControllerV2<
await this._encryptionPublicKeyManager.approveMessage(msgParams); await this._encryptionPublicKeyManager.approveMessage(msgParams);
// EncryptionPublicKey message // EncryptionPublicKey message
const publicKey = await this._keyringController.getEncryptionPublicKey( const publicKey = await this._getEncryptionPublicKey(
cleanMessageParams.from, cleanMessageParams.from,
); );

View File

@ -904,9 +904,8 @@ export default class MetamaskController extends EventEmitter {
(address) => !identities[address], (address) => !identities[address],
); );
const keyringTypesWithMissingIdentities = const keyringTypesWithMissingIdentities =
accountsMissingIdentities.map( accountsMissingIdentities.map((address) =>
(address) => this.coreKeyringController.getAccountKeyringType(address),
this.keyringController.getKeyringForAccount(address)?.type,
); );
const identitiesCount = Object.keys(identities || {}).length; const identitiesCount = Object.keys(identities || {}).length;
@ -1349,7 +1348,14 @@ export default class MetamaskController extends EventEmitter {
`${this.approvalController.name}:rejectRequest`, `${this.approvalController.name}:rejectRequest`,
], ],
}), }),
keyringController: this.keyringController, getEncryptionPublicKey:
this.keyringController.getEncryptionPublicKey.bind(
this.keyringController,
),
getAccountKeyringType:
this.coreKeyringController.getAccountKeyringType.bind(
this.coreKeyringController,
),
getState: this.getState.bind(this), getState: this.getState.bind(this),
metricsEvent: this.metaMetricsController.trackEvent.bind( metricsEvent: this.metaMetricsController.trackEvent.bind(
this.metaMetricsController, this.metaMetricsController,
@ -3265,8 +3271,10 @@ export default class MetamaskController extends EventEmitter {
* @returns {'hardware' | 'imported' | 'MetaMask'} * @returns {'hardware' | 'imported' | 'MetaMask'}
*/ */
async getAccountType(address) { async getAccountType(address) {
const keyring = await this.keyringController.getKeyringForAccount(address); const keyringType = await this.coreKeyringController.getAccountKeyringType(
switch (keyring.type) { address,
);
switch (keyringType) {
case KeyringType.trezor: case KeyringType.trezor:
case KeyringType.lattice: case KeyringType.lattice:
case KeyringType.qr: case KeyringType.qr:
@ -3288,7 +3296,9 @@ export default class MetamaskController extends EventEmitter {
* @returns {'ledger' | 'lattice' | 'N/A' | string} * @returns {'ledger' | 'lattice' | 'N/A' | string}
*/ */
async getDeviceModel(address) { async getDeviceModel(address) {
const keyring = await this.keyringController.getKeyringForAccount(address); const keyring = await this.coreKeyringController.getKeyringForAccount(
address,
);
switch (keyring.type) { switch (keyring.type) {
case KeyringType.trezor: case KeyringType.trezor:
return keyring.getModel(); return keyring.getModel();
@ -3527,7 +3537,9 @@ export default class MetamaskController extends EventEmitter {
this.custodyController.removeAccount(address); this.custodyController.removeAccount(address);
///: END:ONLY_INCLUDE_IN(build-mmi) ///: END:ONLY_INCLUDE_IN(build-mmi)
const keyring = await this.keyringController.getKeyringForAccount(address); const keyring = await this.coreKeyringController.getKeyringForAccount(
address,
);
// Remove account from the keyring // Remove account from the keyring
await this.keyringController.removeAccount(address); await this.keyringController.removeAccount(address);
const updatedKeyringAccounts = keyring ? await keyring.getAccounts() : {}; const updatedKeyringAccounts = keyring ? await keyring.getAccounts() : {};

View File

@ -859,7 +859,10 @@ describe('MetaMaskController', function () {
sinon.stub(metamaskController.keyringController, 'removeAccount'); sinon.stub(metamaskController.keyringController, 'removeAccount');
sinon.stub(metamaskController, 'removeAllAccountPermissions'); sinon.stub(metamaskController, 'removeAllAccountPermissions');
sinon sinon
.stub(metamaskController.keyringController, 'getKeyringForAccount') .stub(
metamaskController.coreKeyringController,
'getKeyringForAccount',
)
.returns(Promise.resolve(mockKeyring)); .returns(Promise.resolve(mockKeyring));
ret = await metamaskController.removeAccount(addressToRemove); ret = await metamaskController.removeAccount(addressToRemove);
@ -906,9 +909,9 @@ describe('MetaMaskController', function () {
it('should return address', async function () { it('should return address', async function () {
assert.equal(ret, '0x1'); assert.equal(ret, '0x1');
}); });
it('should call keyringController.getKeyringForAccount', async function () { it('should call coreKeyringController.getKeyringForAccount', async function () {
assert( assert(
metamaskController.keyringController.getKeyringForAccount.calledWith( metamaskController.coreKeyringController.getKeyringForAccount.calledWith(
addressToRemove, addressToRemove,
), ),
); );