Skip to content

Commit

Permalink
add methods 'signMessage' and 'signMessageBytes'
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoacosta74 authored and rileystephens28 committed Jun 17, 2024
1 parent 4c6fb40 commit 872539e
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 64 deletions.
157 changes: 117 additions & 40 deletions src/wallet/hdwallet.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ import { randomBytes } from '../crypto/index.js';
import { getZoneForAddress, isQiAddress } from '../utils/index.js';
import { Zone } from '../constants/index.js';
import { TransactionRequest, Provider, TransactionResponse } from '../providers/index.js';
import { AllowedCoinType } from '../constants/index.js';
import { AllowedCoinType } from "../constants/index.js";
import { QiHDWallet } from "./qi-hdwallet.js";
import { QuaiHDWallet } from "./quai-hdwallet.js";

export interface NeuteredAddressInfo {
pubKey: string;
Expand All @@ -17,11 +19,22 @@ export interface NeuteredAddressInfo {
zone: Zone;
}

export interface SerializedHDWallet {
version: number;
phrase: string;
coinType: AllowedCoinType;
addresses: Array<NeuteredAddressInfo>;
provider: Provider | undefined;
};

// Constant to represent the maximum attempt to derive an address
const MAX_ADDRESS_DERIVATION_ATTEMPTS = 10000000;

export abstract class AbstractHDWallet {
protected static _coinType?: AllowedCoinType;

protected static _version: number = 1;

protected static _coinType?: AllowedCoinType;

// Map of account number to HDNodeWallet
protected _accounts: Map<number, HDNodeWallet> = new Map();
Expand All @@ -45,10 +58,10 @@ export abstract class AbstractHDWallet {
protected static parentPath(coinType: number): string {
return `m/44'/${coinType}'`;
}

protected coinType(): number {
return (this.constructor as typeof AbstractHDWallet)._coinType!;
}
protected coinType(): AllowedCoinType {
return (this.constructor as typeof AbstractHDWallet)._coinType!;
}

// helper methods that adds an account HD node to the HD wallet following the BIP-44 standard.
protected addAccount(accountIndex: number): void {
Expand All @@ -69,9 +82,9 @@ export abstract class AbstractHDWallet {
return false;
}
const isCorrectShard = addressZone === zone;
const isCorrectLedger = this.coinType() === 969 ? isQiAddress(address) : !isQiAddress(address);
const isCorrectLedger = (this.coinType() === 969) ? isQiAddress(address) : !isQiAddress(address);
return isCorrectShard && isCorrectLedger;
};
}
// derive the address node
const accountNode = this._accounts.get(account);
const changeIndex = isChange ? 1 : 0;
Expand All @@ -83,16 +96,19 @@ export abstract class AbstractHDWallet {
addrIndex++;
// put a hard limit on the number of addresses to derive
if (addrIndex - startingIndex > MAX_ADDRESS_DERIVATION_ATTEMPTS) {
throw new Error(
`Failed to derive a valid address for the zone ${zone} after ${MAX_ADDRESS_DERIVATION_ATTEMPTS} attempts.`,
);
throw new Error(`Failed to derive a valid address for the zone ${zone} after ${MAX_ADDRESS_DERIVATION_ATTEMPTS} attempts.`);
}
} while (!isValidAddressForZone(addressNode.address));

return addressNode;
}

public addAddress(account: number, addressIndex: number, zone: Zone): NeuteredAddressInfo {
public addAddress(account: number, addressIndex: number, isChange: boolean = false): NeuteredAddressInfo {
return this._addAddress(this._addresses, account, addressIndex, isChange);
}

// helper method to add an address to the wallet address map
protected _addAddress(addressMap: Map<string, NeuteredAddressInfo> ,account: number, addressIndex: number, isChange: boolean = false): NeuteredAddressInfo {
if (!this._accounts.has(account)) {
this.addAccount(account);
}
Expand All @@ -103,23 +119,30 @@ export abstract class AbstractHDWallet {
}
});

const addressNode = this.deriveAddress(account, addressIndex, zone);
// derive the address node
const changeIndex = isChange ? 1 : 0;
const addressNode = this._root.deriveChild(account).deriveChild(changeIndex).deriveChild(addressIndex);
const zone = getZoneForAddress(addressNode.address);
if (!zone) {
throw new Error(`Failed to derive a valid address zone for the index ${addressIndex}`);
}

// create the NeuteredAddressInfo object and update the maps
// create the NeuteredAddressInfo object and update the map
const neuteredAddressInfo = {
pubKey: addressNode.publicKey,
address: addressNode.address,
account: account,
index: addressNode.index,
change: false,
change: isChange,
zone: zone,
};

this._addresses.set(neuteredAddressInfo.address, neuteredAddressInfo);
addressMap.set(neuteredAddressInfo.address, neuteredAddressInfo);

return neuteredAddressInfo;
}
public getNextAddress(accountIndex: number, zone: Zone): NeuteredAddressInfo {

public getNextAddress(accountIndex: number, zone: Zone): NeuteredAddressInfo {
this.validateZone(zone);
if (!this._accounts.has(accountIndex)) {
this.addAccount(accountIndex);
Expand Down Expand Up @@ -180,33 +203,16 @@ export abstract class AbstractHDWallet {
return (this as any).createInstance(mnemonic);
}

static createRandom<T extends AbstractHDWallet>(
this: new (root: HDNodeWallet) => T,
password?: string,
wordlist?: Wordlist,
): T {
if (password == null) {
password = '';
}
if (wordlist == null) {
wordlist = LangEn.wordlist();
}
static createRandom<T extends AbstractHDWallet>(this: new (root: HDNodeWallet) => T, password?: string, wordlist?: Wordlist): T {
if (password == null) { password = ""; }
if (wordlist == null) { wordlist = LangEn.wordlist(); }
const mnemonic = Mnemonic.fromEntropy(randomBytes(16), password, wordlist);
return (this as any).createInstance(mnemonic);
}

static fromPhrase<T extends AbstractHDWallet>(
this: new (root: HDNodeWallet) => T,
phrase: string,
password?: string,
wordlist?: Wordlist,
): T {
if (password == null) {
password = '';
}
if (wordlist == null) {
wordlist = LangEn.wordlist();
}
static fromPhrase<T extends AbstractHDWallet>(this: new (root: HDNodeWallet) => T, phrase: string, password?: string, wordlist?: Wordlist): T {
if (password == null) { password = ""; }
if (wordlist == null) { wordlist = LangEn.wordlist(); }
const mnemonic = Mnemonic.fromPhrase(phrase, password, wordlist);
return (this as any).createInstance(mnemonic);
}
Expand All @@ -224,4 +230,75 @@ export abstract class AbstractHDWallet {
throw new Error(`Invalid zone: ${zone}`);
}
}

// Returns the HD node that derives the address.
// If the address is not found in the wallet, an error is thrown.
protected _getHDNodeForAddress(addr: string): HDNodeWallet {
const addressInfo = this._addresses.get(addr);
if (!addressInfo) {
throw new Error(`Address ${addr} is not known to this wallet`);
}

// derive a HD node for the from address using the index
const accountNode = this._accounts.get(addressInfo.account);
if (!accountNode) {
throw new Error(`Account ${addressInfo.account} not found`);
}
const changeNode = accountNode.deriveChild(0);
return changeNode.deriveChild(addressInfo.index);
}

abstract signMessage(address: string, message: string | Uint8Array): Promise<string>

public async serialize(): Promise<SerializedHDWallet> {
const addresses = Array.from(this._addresses.values());
return {
version: (this.constructor as any)._version,
phrase: this._root.mnemonic!.phrase,
coinType: this.coinType(),
addresses: addresses,
provider: this.provider,
};
}

static async deserialize<T extends AbstractHDWallet>(
this: new (root: HDNodeWallet, provider?: Provider) => T,
serialized: SerializedHDWallet
): Promise<QuaiHDWallet | QiHDWallet> {
// validate the version and coinType
if (serialized.version !== (this as any)._version) {
throw new Error(`Invalid version ${serialized.version} for wallet (expected ${(this as any)._version})`);
}
if (serialized.coinType !== (this as any)._coinType) {
throw new Error(`Invalid coinType ${serialized.coinType} for wallet (expected ${(this as any)._coinType})`);
}
// create the wallet instance
const mnemonic = Mnemonic.fromPhrase(serialized.phrase);
const path = (this as any).parentPath(serialized.coinType);
const root = HDNodeWallet.fromMnemonic(mnemonic, path );
const wallet = new this(root, serialized.provider);

// import the addresses
wallet.importSerializedAddresses(wallet._addresses,serialized.addresses);

return wallet as T;
}

// This method is used to import addresses from a serialized wallet.
// It validates the addresses and adds them to the wallet.
protected importSerializedAddresses(addressMap: Map<string, NeuteredAddressInfo>, addresses: NeuteredAddressInfo[]): void {
for (const addressInfo of addresses) {
const newAddressInfo = this._addAddress(addressMap, addressInfo.account, addressInfo.index, addressInfo.change);
// validate the address info
if (addressInfo.address !== newAddressInfo.address) {
throw new Error(`Address mismatch: ${addressInfo.address} != ${newAddressInfo.address}`);
}
if (addressInfo.pubKey !== newAddressInfo.pubKey) {
throw new Error(`Public key mismatch: ${addressInfo.pubKey} != ${newAddressInfo.pubKey}`);
}
if (addressInfo.zone !== newAddressInfo.zone) {
throw new Error(`Zone mismatch: ${addressInfo.zone} != ${newAddressInfo.zone}`);
}
}
}
}
88 changes: 81 additions & 7 deletions src/wallet/qi-hdwallet.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import { AbstractHDWallet, NeuteredAddressInfo } from './hdwallet.js';
import { HDNodeWallet } from './hdnodewallet.js';


import { AbstractHDWallet, NeuteredAddressInfo, SerializedHDWallet } from './hdwallet';
import { HDNodeWallet } from "./hdnodewallet";
import { QiTransactionRequest, Provider, TransactionResponse } from '../providers/index.js';
import { computeAddress } from '../address/index.js';
import { getBytes, hexlify } from '../utils/index.js';
Expand All @@ -19,7 +21,17 @@ type OutpointInfo = {
account?: number;
};

interface SerializedQiHDWallet extends SerializedHDWallet{
outpoints: OutpointInfo[];
changeAddresses: NeuteredAddressInfo[];
gapAddresses: NeuteredAddressInfo[];
gapChangeAddresses: NeuteredAddressInfo[];
}

export class QiHDWallet extends AbstractHDWallet {

protected static _version: number = 1;

protected static _GAP_LIMIT: number = 20;

protected static _coinType: AllowedCoinType = 969;
Expand Down Expand Up @@ -307,9 +319,71 @@ export class QiHDWallet extends AbstractHDWallet {
return gapAddresses;
}

public getGapChangeAddressesForZone(zone: Zone): NeuteredAddressInfo[] {
this.validateZone(zone);
const gapChangeAddresses = this._gapChangeAddresses.filter((addressInfo) => addressInfo.zone === zone);
return gapChangeAddresses;
}
public getGapChangeAddressesForZone(zone: Zone): NeuteredAddressInfo[] {
this.validateZone(zone);
const gapChangeAddresses = this._gapChangeAddresses.filter((addressInfo) => addressInfo.zone === zone);
return gapChangeAddresses;
}

public async signMessage(address: string, message: string | Uint8Array): Promise<string> {
const addrNode = this._getHDNodeForAddress(address);
const privKey = addrNode.privateKey;
const digest = keccak_256(message);
const signature = schnorr.sign(digest, getBytes(privKey));
return hexlify(signature);
}

public async serialize(): Promise<SerializedQiHDWallet> {
const hdwalletSerialized = await super.serialize();
return {
outpoints: this._outpoints,
changeAddresses: Array.from(this._changeAddresses.values()),
gapAddresses: this._gapAddresses,
gapChangeAddresses: this._gapChangeAddresses,
...hdwalletSerialized,
};
}

public static async deserialize(serialized: SerializedQiHDWallet): Promise<QiHDWallet> {
const wallet = await super.deserialize<QiHDWallet>(serialized) as QiHDWallet;
// import the change addresses
wallet.importSerializedAddresses(wallet._changeAddresses, serialized.changeAddresses);

// import the gap addresses, verifying they exist in the wallet
for (const gapAddressInfo of serialized.gapAddresses) {
const gapAddress = gapAddressInfo.address;
if (!wallet._addresses.has(gapAddress)) {
throw new Error(`Address ${gapAddress} not found in wallet`);
}
wallet._gapAddresses.push(gapAddressInfo);

}
// import the gap change addresses, verifying they exist in the wallet
for (const gapChangeAddressInfo of serialized.gapChangeAddresses) {
const gapChangeAddress = gapChangeAddressInfo.address;
if (!wallet._changeAddresses.has(gapChangeAddress)) {
throw new Error(`Address ${gapChangeAddress} not found in wallet`);
}
wallet._gapChangeAddresses.push(gapChangeAddressInfo);
}

// validate the outpoints and import them
for (const outpointInfo of serialized.outpoints) {
// check the zone is valid
wallet.validateZone(outpointInfo.zone);
// check the outpoint address is known to the wallet
if (!wallet._addresses.has(outpointInfo.address)) {
throw new Error(`Address ${outpointInfo.address} not found in wallet`);
}
const outpoint = outpointInfo.outpoint;
// TODO: implement a more robust check for Outpoint
// check the Outpoint fields are not empty
if (outpoint.Txhash == null || outpoint.Index == null || outpoint.Denomination == null) {
throw new Error(`Invalid Outpoint: ${JSON.stringify(outpoint)} `);
}
wallet._outpoints.push(outpointInfo);
}
return wallet;

}
}
27 changes: 10 additions & 17 deletions src/wallet/quai-hdwallet.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,18 @@ import { resolveAddress } from '../address/index.js';
import { AllowedCoinType } from '../constants/index.js';

export class QuaiHDWallet extends AbstractHDWallet {

protected static _version: number = 1;

protected static _coinType: AllowedCoinType = 994;

private constructor(root: HDNodeWallet, provider?: Provider) {
super(root, provider);
}

private _getHDNode(from: string): HDNodeWallet {
const fromAddressInfo = this._addresses.get(from);
if (!fromAddressInfo) {
throw new Error(`Address ${from} is not known to wallet`);
}

// derive a HD node for the from address using the index
const accountNode = this._accounts.get(fromAddressInfo.account);
if (!accountNode) {
throw new Error(`Account ${fromAddressInfo.account} not found`);
}
const changeNode = accountNode.deriveChild(0);
return changeNode.deriveChild(fromAddressInfo.index);
}

public async signTransaction(tx: QuaiTransactionRequest): Promise<string> {
const from = await resolveAddress(tx.from);
const fromNode = this._getHDNode(from);
const fromNode = this._getHDNodeForAddress(from);
const signedTx = await fromNode.signTransaction(tx);
return signedTx;
}
Expand All @@ -38,8 +26,13 @@ export class QuaiHDWallet extends AbstractHDWallet {
throw new Error('Provider is not set');
}
const from = await resolveAddress(tx.from);
const fromNode = this._getHDNode(from);
const fromNode = this._getHDNodeForAddress(from);
const fromNodeConnected = fromNode.connect(this.provider);
return await fromNodeConnected.sendTransaction(tx);
}

public async signMessage(address: string, message: string | Uint8Array): Promise<string> {
const addrNode = this._getHDNodeForAddress(address);
return await addrNode.signMessage(message);
}
}

0 comments on commit 872539e

Please sign in to comment.