diff --git a/src/wallet/hdwallet.ts b/src/wallet/hdwallet.ts index 69e2f499..9a399220 100644 --- a/src/wallet/hdwallet.ts +++ b/src/wallet/hdwallet.ts @@ -6,7 +6,9 @@ import { randomBytes } from '../crypto/index.js'; import { getZoneForAddress, isQiAddress } from '../utils/index.js'; import { Zone } from '../constants/index.js'; import { TransactionRequest, Provider, TransactionResponse } from '../providers/index.js'; -import { AllowedCoinType } from '../constants/index.js'; +import { AllowedCoinType } from "../constants/index.js"; +import { QiHDWallet } from "./qi-hdwallet.js"; +import { QuaiHDWallet } from "./quai-hdwallet.js"; export interface NeuteredAddressInfo { pubKey: string; @@ -17,11 +19,22 @@ export interface NeuteredAddressInfo { zone: Zone; } +export interface SerializedHDWallet { + version: number; + phrase: string; + coinType: AllowedCoinType; + addresses: Array; + provider: Provider | undefined; + }; + // Constant to represent the maximum attempt to derive an address const MAX_ADDRESS_DERIVATION_ATTEMPTS = 10000000; export abstract class AbstractHDWallet { - protected static _coinType?: AllowedCoinType; + + protected static _version: number = 1; + + protected static _coinType?: AllowedCoinType; // Map of account number to HDNodeWallet protected _accounts: Map = new Map(); @@ -45,10 +58,10 @@ export abstract class AbstractHDWallet { protected static parentPath(coinType: number): string { return `m/44'/${coinType}'`; } - - protected coinType(): number { - return (this.constructor as typeof AbstractHDWallet)._coinType!; - } + + protected coinType(): AllowedCoinType { + return (this.constructor as typeof AbstractHDWallet)._coinType!; + } // helper methods that adds an account HD node to the HD wallet following the BIP-44 standard. protected addAccount(accountIndex: number): void { @@ -69,9 +82,9 @@ export abstract class AbstractHDWallet { return false; } const isCorrectShard = addressZone === zone; - const isCorrectLedger = this.coinType() === 969 ? isQiAddress(address) : !isQiAddress(address); + const isCorrectLedger = (this.coinType() === 969) ? isQiAddress(address) : !isQiAddress(address); return isCorrectShard && isCorrectLedger; - }; + } // derive the address node const accountNode = this._accounts.get(account); const changeIndex = isChange ? 1 : 0; @@ -83,16 +96,19 @@ export abstract class AbstractHDWallet { addrIndex++; // put a hard limit on the number of addresses to derive if (addrIndex - startingIndex > MAX_ADDRESS_DERIVATION_ATTEMPTS) { - throw new Error( - `Failed to derive a valid address for the zone ${zone} after ${MAX_ADDRESS_DERIVATION_ATTEMPTS} attempts.`, - ); + throw new Error(`Failed to derive a valid address for the zone ${zone} after ${MAX_ADDRESS_DERIVATION_ATTEMPTS} attempts.`); } } while (!isValidAddressForZone(addressNode.address)); return addressNode; } - public addAddress(account: number, addressIndex: number, zone: Zone): NeuteredAddressInfo { + public addAddress(account: number, addressIndex: number, isChange: boolean = false): NeuteredAddressInfo { + return this._addAddress(this._addresses, account, addressIndex, isChange); + } + + // helper method to add an address to the wallet address map + protected _addAddress(addressMap: Map ,account: number, addressIndex: number, isChange: boolean = false): NeuteredAddressInfo { if (!this._accounts.has(account)) { this.addAccount(account); } @@ -103,23 +119,30 @@ export abstract class AbstractHDWallet { } }); - const addressNode = this.deriveAddress(account, addressIndex, zone); + // derive the address node + const changeIndex = isChange ? 1 : 0; + const addressNode = this._root.deriveChild(account).deriveChild(changeIndex).deriveChild(addressIndex); + const zone = getZoneForAddress(addressNode.address); + if (!zone) { + throw new Error(`Failed to derive a valid address zone for the index ${addressIndex}`); + } - // create the NeuteredAddressInfo object and update the maps + // create the NeuteredAddressInfo object and update the map const neuteredAddressInfo = { pubKey: addressNode.publicKey, address: addressNode.address, account: account, index: addressNode.index, - change: false, + change: isChange, zone: zone, }; - this._addresses.set(neuteredAddressInfo.address, neuteredAddressInfo); + addressMap.set(neuteredAddressInfo.address, neuteredAddressInfo); return neuteredAddressInfo; } - public getNextAddress(accountIndex: number, zone: Zone): NeuteredAddressInfo { + + public getNextAddress(accountIndex: number, zone: Zone): NeuteredAddressInfo { this.validateZone(zone); if (!this._accounts.has(accountIndex)) { this.addAccount(accountIndex); @@ -180,33 +203,16 @@ export abstract class AbstractHDWallet { return (this as any).createInstance(mnemonic); } - static createRandom( - this: new (root: HDNodeWallet) => T, - password?: string, - wordlist?: Wordlist, - ): T { - if (password == null) { - password = ''; - } - if (wordlist == null) { - wordlist = LangEn.wordlist(); - } + static createRandom(this: new (root: HDNodeWallet) => T, password?: string, wordlist?: Wordlist): T { + if (password == null) { password = ""; } + if (wordlist == null) { wordlist = LangEn.wordlist(); } const mnemonic = Mnemonic.fromEntropy(randomBytes(16), password, wordlist); return (this as any).createInstance(mnemonic); } - static fromPhrase( - this: new (root: HDNodeWallet) => T, - phrase: string, - password?: string, - wordlist?: Wordlist, - ): T { - if (password == null) { - password = ''; - } - if (wordlist == null) { - wordlist = LangEn.wordlist(); - } + static fromPhrase(this: new (root: HDNodeWallet) => T, phrase: string, password?: string, wordlist?: Wordlist): T { + if (password == null) { password = ""; } + if (wordlist == null) { wordlist = LangEn.wordlist(); } const mnemonic = Mnemonic.fromPhrase(phrase, password, wordlist); return (this as any).createInstance(mnemonic); } @@ -224,4 +230,75 @@ export abstract class AbstractHDWallet { throw new Error(`Invalid zone: ${zone}`); } } + + // Returns the HD node that derives the address. + // If the address is not found in the wallet, an error is thrown. + protected _getHDNodeForAddress(addr: string): HDNodeWallet { + const addressInfo = this._addresses.get(addr); + if (!addressInfo) { + throw new Error(`Address ${addr} is not known to this wallet`); + } + + // derive a HD node for the from address using the index + const accountNode = this._accounts.get(addressInfo.account); + if (!accountNode) { + throw new Error(`Account ${addressInfo.account} not found`); + } + const changeNode = accountNode.deriveChild(0); + return changeNode.deriveChild(addressInfo.index); + } + + abstract signMessage(address: string, message: string | Uint8Array): Promise + + public async serialize(): Promise { + const addresses = Array.from(this._addresses.values()); + return { + version: (this.constructor as any)._version, + phrase: this._root.mnemonic!.phrase, + coinType: this.coinType(), + addresses: addresses, + provider: this.provider, + }; + } + + static async deserialize( + this: new (root: HDNodeWallet, provider?: Provider) => T, + serialized: SerializedHDWallet + ): Promise { + // validate the version and coinType + if (serialized.version !== (this as any)._version) { + throw new Error(`Invalid version ${serialized.version} for wallet (expected ${(this as any)._version})`); + } + if (serialized.coinType !== (this as any)._coinType) { + throw new Error(`Invalid coinType ${serialized.coinType} for wallet (expected ${(this as any)._coinType})`); + } + // create the wallet instance + const mnemonic = Mnemonic.fromPhrase(serialized.phrase); + const path = (this as any).parentPath(serialized.coinType); + const root = HDNodeWallet.fromMnemonic(mnemonic, path ); + const wallet = new this(root, serialized.provider); + + // import the addresses + wallet.importSerializedAddresses(wallet._addresses,serialized.addresses); + + return wallet as T; + } + + // This method is used to import addresses from a serialized wallet. + // It validates the addresses and adds them to the wallet. + protected importSerializedAddresses(addressMap: Map, addresses: NeuteredAddressInfo[]): void { + for (const addressInfo of addresses) { + const newAddressInfo = this._addAddress(addressMap, addressInfo.account, addressInfo.index, addressInfo.change); + // validate the address info + if (addressInfo.address !== newAddressInfo.address) { + throw new Error(`Address mismatch: ${addressInfo.address} != ${newAddressInfo.address}`); + } + if (addressInfo.pubKey !== newAddressInfo.pubKey) { + throw new Error(`Public key mismatch: ${addressInfo.pubKey} != ${newAddressInfo.pubKey}`); + } + if (addressInfo.zone !== newAddressInfo.zone) { + throw new Error(`Zone mismatch: ${addressInfo.zone} != ${newAddressInfo.zone}`); + } + } + } } diff --git a/src/wallet/qi-hdwallet.ts b/src/wallet/qi-hdwallet.ts index 03b37197..d7a03acb 100644 --- a/src/wallet/qi-hdwallet.ts +++ b/src/wallet/qi-hdwallet.ts @@ -1,5 +1,7 @@ -import { AbstractHDWallet, NeuteredAddressInfo } from './hdwallet.js'; -import { HDNodeWallet } from './hdnodewallet.js'; + + +import { AbstractHDWallet, NeuteredAddressInfo, SerializedHDWallet } from './hdwallet'; +import { HDNodeWallet } from "./hdnodewallet"; import { QiTransactionRequest, Provider, TransactionResponse } from '../providers/index.js'; import { computeAddress } from '../address/index.js'; import { getBytes, hexlify } from '../utils/index.js'; @@ -19,7 +21,17 @@ type OutpointInfo = { account?: number; }; +interface SerializedQiHDWallet extends SerializedHDWallet{ + outpoints: OutpointInfo[]; + changeAddresses: NeuteredAddressInfo[]; + gapAddresses: NeuteredAddressInfo[]; + gapChangeAddresses: NeuteredAddressInfo[]; +} + export class QiHDWallet extends AbstractHDWallet { + + protected static _version: number = 1; + protected static _GAP_LIMIT: number = 20; protected static _coinType: AllowedCoinType = 969; @@ -307,9 +319,71 @@ export class QiHDWallet extends AbstractHDWallet { return gapAddresses; } - public getGapChangeAddressesForZone(zone: Zone): NeuteredAddressInfo[] { - this.validateZone(zone); - const gapChangeAddresses = this._gapChangeAddresses.filter((addressInfo) => addressInfo.zone === zone); - return gapChangeAddresses; - } + public getGapChangeAddressesForZone(zone: Zone): NeuteredAddressInfo[] { + this.validateZone(zone); + const gapChangeAddresses = this._gapChangeAddresses.filter((addressInfo) => addressInfo.zone === zone); + return gapChangeAddresses; + } + + public async signMessage(address: string, message: string | Uint8Array): Promise { + const addrNode = this._getHDNodeForAddress(address); + const privKey = addrNode.privateKey; + const digest = keccak_256(message); + const signature = schnorr.sign(digest, getBytes(privKey)); + return hexlify(signature); + } + + public async serialize(): Promise { + const hdwalletSerialized = await super.serialize(); + return { + outpoints: this._outpoints, + changeAddresses: Array.from(this._changeAddresses.values()), + gapAddresses: this._gapAddresses, + gapChangeAddresses: this._gapChangeAddresses, + ...hdwalletSerialized, + }; + } + + public static async deserialize(serialized: SerializedQiHDWallet): Promise { + const wallet = await super.deserialize(serialized) as QiHDWallet; + // import the change addresses + wallet.importSerializedAddresses(wallet._changeAddresses, serialized.changeAddresses); + + // import the gap addresses, verifying they exist in the wallet + for (const gapAddressInfo of serialized.gapAddresses) { + const gapAddress = gapAddressInfo.address; + if (!wallet._addresses.has(gapAddress)) { + throw new Error(`Address ${gapAddress} not found in wallet`); + } + wallet._gapAddresses.push(gapAddressInfo); + + } + // import the gap change addresses, verifying they exist in the wallet + for (const gapChangeAddressInfo of serialized.gapChangeAddresses) { + const gapChangeAddress = gapChangeAddressInfo.address; + if (!wallet._changeAddresses.has(gapChangeAddress)) { + throw new Error(`Address ${gapChangeAddress} not found in wallet`); + } + wallet._gapChangeAddresses.push(gapChangeAddressInfo); + } + + // validate the outpoints and import them + for (const outpointInfo of serialized.outpoints) { + // check the zone is valid + wallet.validateZone(outpointInfo.zone); + // check the outpoint address is known to the wallet + if (!wallet._addresses.has(outpointInfo.address)) { + throw new Error(`Address ${outpointInfo.address} not found in wallet`); + } + const outpoint = outpointInfo.outpoint; + // TODO: implement a more robust check for Outpoint + // check the Outpoint fields are not empty + if (outpoint.Txhash == null || outpoint.Index == null || outpoint.Denomination == null) { + throw new Error(`Invalid Outpoint: ${JSON.stringify(outpoint)} `); + } + wallet._outpoints.push(outpointInfo); + } + return wallet; + + } } diff --git a/src/wallet/quai-hdwallet.ts b/src/wallet/quai-hdwallet.ts index 41002c62..d2199da0 100644 --- a/src/wallet/quai-hdwallet.ts +++ b/src/wallet/quai-hdwallet.ts @@ -5,30 +5,18 @@ import { resolveAddress } from '../address/index.js'; import { AllowedCoinType } from '../constants/index.js'; export class QuaiHDWallet extends AbstractHDWallet { + + protected static _version: number = 1; + protected static _coinType: AllowedCoinType = 994; private constructor(root: HDNodeWallet, provider?: Provider) { super(root, provider); } - private _getHDNode(from: string): HDNodeWallet { - const fromAddressInfo = this._addresses.get(from); - if (!fromAddressInfo) { - throw new Error(`Address ${from} is not known to wallet`); - } - - // derive a HD node for the from address using the index - const accountNode = this._accounts.get(fromAddressInfo.account); - if (!accountNode) { - throw new Error(`Account ${fromAddressInfo.account} not found`); - } - const changeNode = accountNode.deriveChild(0); - return changeNode.deriveChild(fromAddressInfo.index); - } - public async signTransaction(tx: QuaiTransactionRequest): Promise { const from = await resolveAddress(tx.from); - const fromNode = this._getHDNode(from); + const fromNode = this._getHDNodeForAddress(from); const signedTx = await fromNode.signTransaction(tx); return signedTx; } @@ -38,8 +26,13 @@ export class QuaiHDWallet extends AbstractHDWallet { throw new Error('Provider is not set'); } const from = await resolveAddress(tx.from); - const fromNode = this._getHDNode(from); + const fromNode = this._getHDNodeForAddress(from); const fromNodeConnected = fromNode.connect(this.provider); return await fromNodeConnected.sendTransaction(tx); } + + public async signMessage(address: string, message: string | Uint8Array): Promise { + const addrNode = this._getHDNodeForAddress(address); + return await addrNode.signMessage(message); + } }