diff --git a/cmd/chantools/signpsbt.go b/cmd/chantools/signpsbt.go index 927881b..50078c4 100644 --- a/cmd/chantools/signpsbt.go +++ b/cmd/chantools/signpsbt.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/base64" "encoding/binary" + "errors" "fmt" "os" @@ -15,6 +16,10 @@ import ( "github.com/spf13/cobra" ) +var ( + errNoPathFound = fmt.Errorf("no matching derivation path found") +) + type signPSBTCommand struct { Psbt string FromRawPsbtFile string @@ -134,93 +139,117 @@ func (c *signPSBTCommand) Execute(_ *cobra.Command, _ []string) error { func signPsbt(rootKey *hdkeychain.ExtendedKey, packet *psbt.Packet, signer *lnd.Signer) error { - // Check that we have an input with a derivation path that belongs to - // the root key. - derivationPath, inputIndex, err := findMatchingDerivationPath( - rootKey, packet, - ) - if err != nil { - return fmt.Errorf("could not find matching derivation path: %w", - err) - } + for inputIndex := range packet.Inputs { + pIn := &packet.Inputs[inputIndex] - if len(derivationPath) < 5 { - return fmt.Errorf("invalid derivation path, expected at least "+ - "5 elements, got %d", len(derivationPath)) - } + // Check that we have an input with a derivation path that + // belongs to the root key. + derivationPath, err := findMatchingDerivationPath(rootKey, pIn) + if errors.Is(err, errNoPathFound) { + log.Infof("No matching derivation path found for "+ + "input %d, skipping", inputIndex) + continue + } + if err != nil { + return fmt.Errorf("could not find matching derivation "+ + "path: %w", err) + } - localKey, err := lnd.DeriveChildren(rootKey, derivationPath) - if err != nil { - return fmt.Errorf("could not derive local key: %w", err) - } + if len(derivationPath) < 5 { + return fmt.Errorf("invalid derivation path, expected "+ + "at least 5 elements, got %d", + len(derivationPath)) + } - if packet.Inputs[inputIndex].WitnessUtxo == nil { - return fmt.Errorf("invalid PSBT, input %d is missing witness "+ - "UTXO", inputIndex) - } - utxo := packet.Inputs[inputIndex].WitnessUtxo - - // The signing is a bit different for P2WPKH, we need to specify the - // pk script as the witness script. - var witnessScript []byte - if txscript.IsPayToWitnessPubKeyHash(utxo.PkScript) { - witnessScript = utxo.PkScript - } else { - if len(packet.Inputs[inputIndex].WitnessScript) == 0 { + localKey, err := lnd.DeriveChildren(rootKey, derivationPath) + if err != nil { + return fmt.Errorf("could not derive local key: %w", err) + } + + if pIn.WitnessUtxo == nil { return fmt.Errorf("invalid PSBT, input %d is missing "+ - "witness script", inputIndex) + "witness UTXO", inputIndex) + } + utxo := pIn.WitnessUtxo + + // The signing is a bit different for P2WPKH, we need to specify + // the pk script as the witness script. + var witnessScript []byte + if txscript.IsPayToWitnessPubKeyHash(utxo.PkScript) { + witnessScript = utxo.PkScript + } else { + if len(pIn.WitnessScript) == 0 { + return fmt.Errorf("invalid PSBT, input %d is "+ + "missing witness script", inputIndex) + } + witnessScript = pIn.WitnessScript } - witnessScript = packet.Inputs[inputIndex].WitnessScript - } - localPrivateKey, err := localKey.ECPrivKey() - if err != nil { - return fmt.Errorf("error getting private key: %w", err) - } - err = signer.AddPartialSignatureForPrivateKey( - packet, localPrivateKey, utxo, witnessScript, inputIndex, - ) - if err != nil { - return fmt.Errorf("error adding partial signature: %w", err) + localPrivateKey, err := localKey.ECPrivKey() + if err != nil { + return fmt.Errorf("error getting private key: %w", err) + } + + // Do we already have a partial signature for our key? + localPubKey := localPrivateKey.PubKey().SerializeCompressed() + haveSig := false + for _, partialSig := range pIn.PartialSigs { + if bytes.Equal(partialSig.PubKey, localPubKey) { + haveSig = true + } + } + if haveSig { + log.Infof("Already have a partial signature for input "+ + "%d and local key %x, skipping", inputIndex, + localPubKey) + continue + } + + err = signer.AddPartialSignatureForPrivateKey( + packet, localPrivateKey, utxo, witnessScript, + inputIndex, + ) + if err != nil { + return fmt.Errorf("error adding partial signature: %w", + err) + } } return nil } func findMatchingDerivationPath(rootKey *hdkeychain.ExtendedKey, - packet *psbt.Packet) ([]uint32, int, error) { + pIn *psbt.PInput) ([]uint32, error) { pubKey, err := rootKey.ECPubKey() if err != nil { - return nil, 0, fmt.Errorf("error getting public key: %w", err) + return nil, fmt.Errorf("error getting public key: %w", err) } pubKeyHash := btcutil.Hash160(pubKey.SerializeCompressed()) fingerprint := binary.LittleEndian.Uint32(pubKeyHash[:4]) - for idx, input := range packet.Inputs { - if len(input.Bip32Derivation) == 0 { - continue - } + if len(pIn.Bip32Derivation) == 0 { + return nil, errNoPathFound + } - for _, derivation := range input.Bip32Derivation { - // A special case where there is only a single - // derivation path and the master key fingerprint is not - // set, we assume we are the correct signer... This - // might not be correct, but we have no way of knowing. - if derivation.MasterKeyFingerprint == 0 && - len(input.Bip32Derivation) == 1 { + for _, derivation := range pIn.Bip32Derivation { + // A special case where there is only a single derivation path + // and the master key fingerprint is not set, we assume we are + // the correct signer... This might not be correct, but we have + // no way of knowing. + if derivation.MasterKeyFingerprint == 0 && + len(pIn.Bip32Derivation) == 1 { - return derivation.Bip32Path, idx, nil - } + return derivation.Bip32Path, nil + } - // The normal case, where a derivation path has the - // master fingerprint set. - if derivation.MasterKeyFingerprint == fingerprint { - return derivation.Bip32Path, idx, nil - } + // The normal case, where a derivation path has the master + // fingerprint set. + if derivation.MasterKeyFingerprint == fingerprint { + return derivation.Bip32Path, nil } } - return nil, 0, fmt.Errorf("no matching derivation path found") + return nil, errNoPathFound }