Skip to content

Commit

Permalink
refactor '_getHDNodeForAddress' to support change addresses
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoacosta74 authored and rileystephens28 committed Jun 17, 2024
1 parent 872539e commit 0d5cdb0
Showing 1 changed file with 66 additions and 36 deletions.
102 changes: 66 additions & 36 deletions src/wallet/hdwallet.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ import { randomBytes } from '../crypto/index.js';
import { getZoneForAddress, isQiAddress } from '../utils/index.js';
import { Zone } from '../constants/index.js';
import { TransactionRequest, Provider, TransactionResponse } from '../providers/index.js';
import { AllowedCoinType } from "../constants/index.js";
import { QiHDWallet } from "./qi-hdwallet.js";
import { QuaiHDWallet } from "./quai-hdwallet.js";
import { AllowedCoinType } from '../constants/index.js';
import { QiHDWallet } from './qi-hdwallet.js';
import { QuaiHDWallet } from './quai-hdwallet.js';

export interface NeuteredAddressInfo {
pubKey: string;
Expand All @@ -19,22 +19,20 @@ export interface NeuteredAddressInfo {
zone: Zone;
}

export interface SerializedHDWallet {
export interface SerializedHDWallet {
version: number;
phrase: string;
coinType: AllowedCoinType;
addresses: Array<NeuteredAddressInfo>;
provider: Provider | undefined;
};
}

// Constant to represent the maximum attempt to derive an address
const MAX_ADDRESS_DERIVATION_ATTEMPTS = 10000000;

export abstract class AbstractHDWallet {

protected static _version: number = 1;

protected static _coinType?: AllowedCoinType;
protected static _coinType?: AllowedCoinType;

// Map of account number to HDNodeWallet
protected _accounts: Map<number, HDNodeWallet> = new Map();
Expand All @@ -58,10 +56,10 @@ export abstract class AbstractHDWallet {
protected static parentPath(coinType: number): string {
return `m/44'/${coinType}'`;
}
protected coinType(): AllowedCoinType {
return (this.constructor as typeof AbstractHDWallet)._coinType!;
}

protected coinType(): AllowedCoinType {
return (this.constructor as typeof AbstractHDWallet)._coinType!;
}

// helper methods that adds an account HD node to the HD wallet following the BIP-44 standard.
protected addAccount(accountIndex: number): void {
Expand All @@ -82,9 +80,9 @@ export abstract class AbstractHDWallet {
return false;
}
const isCorrectShard = addressZone === zone;
const isCorrectLedger = (this.coinType() === 969) ? isQiAddress(address) : !isQiAddress(address);
const isCorrectLedger = this.coinType() === 969 ? isQiAddress(address) : !isQiAddress(address);
return isCorrectShard && isCorrectLedger;
}
};
// derive the address node
const accountNode = this._accounts.get(account);
const changeIndex = isChange ? 1 : 0;
Expand All @@ -96,19 +94,26 @@ export abstract class AbstractHDWallet {
addrIndex++;
// put a hard limit on the number of addresses to derive
if (addrIndex - startingIndex > MAX_ADDRESS_DERIVATION_ATTEMPTS) {
throw new Error(`Failed to derive a valid address for the zone ${zone} after ${MAX_ADDRESS_DERIVATION_ATTEMPTS} attempts.`);
throw new Error(
`Failed to derive a valid address for the zone ${zone} after ${MAX_ADDRESS_DERIVATION_ATTEMPTS} attempts.`,
);
}
} while (!isValidAddressForZone(addressNode.address));

return addressNode;
}

public addAddress(account: number, addressIndex: number, isChange: boolean = false): NeuteredAddressInfo {
public addAddress(account: number, addressIndex: number, isChange: boolean = false): NeuteredAddressInfo {
return this._addAddress(this._addresses, account, addressIndex, isChange);
}

// helper method to add an address to the wallet address map
protected _addAddress(addressMap: Map<string, NeuteredAddressInfo> ,account: number, addressIndex: number, isChange: boolean = false): NeuteredAddressInfo {
protected _addAddress(
addressMap: Map<string, NeuteredAddressInfo>,
account: number,
addressIndex: number,
isChange: boolean = false,
): NeuteredAddressInfo {
if (!this._accounts.has(account)) {
this.addAccount(account);
}
Expand Down Expand Up @@ -142,7 +147,7 @@ export abstract class AbstractHDWallet {
return neuteredAddressInfo;
}

public getNextAddress(accountIndex: number, zone: Zone): NeuteredAddressInfo {
public getNextAddress(accountIndex: number, zone: Zone): NeuteredAddressInfo {
this.validateZone(zone);
if (!this._accounts.has(accountIndex)) {
this.addAccount(accountIndex);
Expand Down Expand Up @@ -203,16 +208,33 @@ export abstract class AbstractHDWallet {
return (this as any).createInstance(mnemonic);
}

static createRandom<T extends AbstractHDWallet>(this: new (root: HDNodeWallet) => T, password?: string, wordlist?: Wordlist): T {
if (password == null) { password = ""; }
if (wordlist == null) { wordlist = LangEn.wordlist(); }
static createRandom<T extends AbstractHDWallet>(
this: new (root: HDNodeWallet) => T,
password?: string,
wordlist?: Wordlist,
): T {
if (password == null) {
password = '';
}
if (wordlist == null) {
wordlist = LangEn.wordlist();
}
const mnemonic = Mnemonic.fromEntropy(randomBytes(16), password, wordlist);
return (this as any).createInstance(mnemonic);
}

static fromPhrase<T extends AbstractHDWallet>(this: new (root: HDNodeWallet) => T, phrase: string, password?: string, wordlist?: Wordlist): T {
if (password == null) { password = ""; }
if (wordlist == null) { wordlist = LangEn.wordlist(); }
static fromPhrase<T extends AbstractHDWallet>(
this: new (root: HDNodeWallet) => T,
phrase: string,
password?: string,
wordlist?: Wordlist,
): T {
if (password == null) {
password = '';
}
if (wordlist == null) {
wordlist = LangEn.wordlist();
}
const mnemonic = Mnemonic.fromPhrase(phrase, password, wordlist);
return (this as any).createInstance(mnemonic);
}
Expand Down Expand Up @@ -244,26 +266,26 @@ export abstract class AbstractHDWallet {
if (!accountNode) {
throw new Error(`Account ${addressInfo.account} not found`);
}
const changeNode = accountNode.deriveChild(0);
const changeIndex = addressInfo.change ? 1 : 0;
const changeNode = accountNode.deriveChild(changeIndex);
return changeNode.deriveChild(addressInfo.index);
}

abstract signMessage(address: string, message: string | Uint8Array): Promise<string>
public async serialize(): Promise<SerializedHDWallet> {
abstract signMessage(address: string, message: string | Uint8Array): Promise<string>;

public async serialize(): Promise<SerializedHDWallet> {
const addresses = Array.from(this._addresses.values());
return {
version: (this.constructor as any)._version,
phrase: this._root.mnemonic!.phrase,
coinType: this.coinType(),
addresses: addresses,
provider: this.provider,
};
}

static async deserialize<T extends AbstractHDWallet>(
this: new (root: HDNodeWallet, provider?: Provider) => T,
serialized: SerializedHDWallet
serialized: SerializedHDWallet,
): Promise<QuaiHDWallet | QiHDWallet> {
// validate the version and coinType
if (serialized.version !== (this as any)._version) {
Expand All @@ -275,20 +297,28 @@ export abstract class AbstractHDWallet {
// create the wallet instance
const mnemonic = Mnemonic.fromPhrase(serialized.phrase);
const path = (this as any).parentPath(serialized.coinType);
const root = HDNodeWallet.fromMnemonic(mnemonic, path );
const wallet = new this(root, serialized.provider);
const root = HDNodeWallet.fromMnemonic(mnemonic, path);
const wallet = new this(root);

// import the addresses
wallet.importSerializedAddresses(wallet._addresses,serialized.addresses);
wallet.importSerializedAddresses(wallet._addresses, serialized.addresses);

return wallet as T;
}

// This method is used to import addresses from a serialized wallet.
// It validates the addresses and adds them to the wallet.
protected importSerializedAddresses(addressMap: Map<string, NeuteredAddressInfo>, addresses: NeuteredAddressInfo[]): void {
protected importSerializedAddresses(
addressMap: Map<string, NeuteredAddressInfo>,
addresses: NeuteredAddressInfo[],
): void {
for (const addressInfo of addresses) {
const newAddressInfo = this._addAddress(addressMap, addressInfo.account, addressInfo.index, addressInfo.change);
const newAddressInfo = this._addAddress(
addressMap,
addressInfo.account,
addressInfo.index,
addressInfo.change,
);
// validate the address info
if (addressInfo.address !== newAddressInfo.address) {
throw new Error(`Address mismatch: ${addressInfo.address} != ${newAddressInfo.address}`);
Expand Down

0 comments on commit 0d5cdb0

Please sign in to comment.