Skip to content

Commit

Permalink
Merge pull request #605 from lightninglabs/accounts
Browse files Browse the repository at this point in the history
accounts: add label and AccountInfo RPC
  • Loading branch information
guggero authored Jul 26, 2023
2 parents b05e640 + 6c147e9 commit 88bb10d
Show file tree
Hide file tree
Showing 24 changed files with 1,186 additions and 265 deletions.
6 changes: 3 additions & 3 deletions accounts/checkers.go
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ func checkSend(ctx context.Context, chainParams *chaincfg.Params,
if len(invoice) > 0 {
payReq, err := zpay32.Decode(invoice, chainParams)
if err != nil {
return fmt.Errorf("error decoding pay req: %v", err)
return fmt.Errorf("error decoding pay req: %w", err)
}

if payReq.MilliSat != nil && *payReq.MilliSat > sendAmt {
Expand All @@ -546,7 +546,7 @@ func checkSend(ctx context.Context, chainParams *chaincfg.Params,

err = service.CheckBalance(acct.ID, sendAmt)
if err != nil {
return fmt.Errorf("error validating account balance: %v", err)
return fmt.Errorf("error validating account balance: %w", err)
}

return nil
Expand Down Expand Up @@ -609,7 +609,7 @@ func checkSendToRoute(ctx context.Context, service Service,

err = service.CheckBalance(acct.ID, sendAmt)
if err != nil {
return fmt.Errorf("error validating account balance: %v", err)
return fmt.Errorf("error validating account balance: %w", err)
}

return nil
Expand Down
10 changes: 5 additions & 5 deletions accounts/checkers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ type mockService struct {
acctBalanceMsat lnwire.MilliSatoshi

trackedInvoices map[lntypes.Hash]AccountID
trackedPayments map[lntypes.Hash]*PaymentEntry
trackedPayments AccountPayments
}

func newMockService() *mockService {
return &mockService{
acctBalanceMsat: 0,
trackedInvoices: make(map[lntypes.Hash]AccountID),
trackedPayments: make(map[lntypes.Hash]*PaymentEntry),
trackedPayments: make(AccountPayments),
}
}

Expand All @@ -68,7 +68,7 @@ func (m *mockService) AssociateInvoice(id AccountID, hash lntypes.Hash) error {
return nil
}

func (m *mockService) TrackPayment(id AccountID, hash lntypes.Hash,
func (m *mockService) TrackPayment(_ AccountID, hash lntypes.Hash,
amt lnwire.MilliSatoshi) error {

m.trackedPayments[hash] = &PaymentEntry{
Expand Down Expand Up @@ -403,8 +403,8 @@ func TestAccountCheckers(t *testing.T) {
acct := &OffChainBalanceAccount{
ID: testID,
Type: TypeInitialBalance,
Invoices: make(map[lntypes.Hash]struct{}),
Payments: make(map[lntypes.Hash]*PaymentEntry),
Invoices: make(AccountInvoices),
Payments: make(AccountPayments),
}
ctx := AddToContext(
context.Background(), KeyAccount, acct,
Expand Down
2 changes: 1 addition & 1 deletion accounts/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func parseRPCMessage(msg *lnrpc.RPCMessage) (proto.Message, error) {
// No, it's a normal message.
parsedMsg, err := mid.ParseProtobuf(msg.TypeName, msg.Serialized)
if err != nil {
return nil, fmt.Errorf("error parsing proto of type %v: %v",
return nil, fmt.Errorf("error parsing proto of type %v: %w",
msg.TypeName, err)
}

Expand Down
20 changes: 15 additions & 5 deletions accounts/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func ParseAccountID(idStr string) (*AccountID, error) {

idBytes, err := hex.DecodeString(idStr)
if err != nil {
return nil, fmt.Errorf("error decoding account ID: %v", err)
return nil, fmt.Errorf("error decoding account ID: %w", err)
}

var id AccountID
Expand All @@ -67,6 +67,12 @@ type PaymentEntry struct {
FullAmount lnwire.MilliSatoshi
}

// AccountInvoices is the set of invoices that are associated with an account.
type AccountInvoices map[lntypes.Hash]struct{}

// AccountPayments is the set of payments that are associated with an account.
type AccountPayments map[lntypes.Hash]*PaymentEntry

// OffChainBalanceAccount holds all information that is needed to keep track of
// a user's off-chain account balance. This balance can only be spent by paying
// invoices.
Expand Down Expand Up @@ -99,11 +105,15 @@ type OffChainBalanceAccount struct {

// Invoices is a list of all invoices that are associated with the
// account.
Invoices map[lntypes.Hash]struct{}
Invoices AccountInvoices

// Payments is a list of all payments that are associated with the
// account and the last status we were aware of.
Payments map[lntypes.Hash]*PaymentEntry
Payments AccountPayments

// Label is an optional label that can be set for the account. If it is
// not empty then it must be unique.
Label string
}

// HasExpired returns true if the account has an expiration date set and that
Expand Down Expand Up @@ -180,8 +190,8 @@ var (
type Store interface {
// NewAccount creates a new OffChainBalanceAccount with the given
// balance and a randomly chosen ID.
NewAccount(balance lnwire.MilliSatoshi,
expirationDate time.Time) (*OffChainBalanceAccount, error)
NewAccount(balance lnwire.MilliSatoshi, expirationDate time.Time,
label string) (*OffChainBalanceAccount, error)

// UpdateAccount writes an account to the database, overwriting the
// existing one if it exists.
Expand Down
101 changes: 80 additions & 21 deletions accounts/rpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ func (s *RPCServer) CreateAccount(ctx context.Context,
req *litrpc.CreateAccountRequest) (*litrpc.CreateAccountResponse,
error) {

log.Infof("[createaccount] balance=%d, expiration=%d",
req.AccountBalance, req.ExpirationDate)
log.Infof("[createaccount] label=%v, balance=%d, expiration=%d",
req.Label, req.AccountBalance, req.ExpirationDate)

var (
balanceMsat lnwire.MilliSatoshi
Expand All @@ -70,9 +70,11 @@ func (s *RPCServer) CreateAccount(ctx context.Context,
balanceMsat = lnwire.NewMSatFromSatoshis(balance)

// Create the actual account in the macaroon account store.
account, err := s.service.NewAccount(balanceMsat, expirationDate)
account, err := s.service.NewAccount(
balanceMsat, expirationDate, req.Label,
)
if err != nil {
return nil, fmt.Errorf("unable to create account: %v", err)
return nil, fmt.Errorf("unable to create account: %w", err)
}

var rootKeyIdSuffix [4]byte
Expand All @@ -91,12 +93,12 @@ func (s *RPCServer) CreateAccount(ctx context.Context,
}},
})
if err != nil {
return nil, fmt.Errorf("error baking account macaroon: %v", err)
return nil, fmt.Errorf("error baking account macaroon: %w", err)
}

macBytes, err := hex.DecodeString(macHex)
if err != nil {
return nil, fmt.Errorf("error decoding account macaroon: %v",
return nil, fmt.Errorf("error decoding account macaroon: %w",
err)
}

Expand All @@ -110,16 +112,13 @@ func (s *RPCServer) CreateAccount(ctx context.Context,
func (s *RPCServer) UpdateAccount(_ context.Context,
req *litrpc.UpdateAccountRequest) (*litrpc.Account, error) {

log.Infof("[updateaccount] id=%s, balance=%d, expiration=%d", req.Id,
req.AccountBalance, req.ExpirationDate)
log.Infof("[updateaccount] id=%s, label=%v, balance=%d, expiration=%d",
req.Id, req.Label, req.AccountBalance, req.ExpirationDate)

// Account ID is always a hex string, convert it to our account ID type.
var accountID AccountID
decoded, err := hex.DecodeString(req.Id)
accountID, err := s.findAccount(req.Id, req.Label)
if err != nil {
return nil, fmt.Errorf("error decoding account ID: %v", err)
return nil, err
}
copy(accountID[:], decoded)

// Ask the service to update the account.
account, err := s.service.UpdateAccount(
Expand All @@ -142,7 +141,7 @@ func (s *RPCServer) ListAccounts(context.Context,
// Retrieve all accounts from the macaroon account store.
accts, err := s.service.Accounts()
if err != nil {
return nil, fmt.Errorf("unable to list accounts: %v", err)
return nil, fmt.Errorf("unable to list accounts: %w", err)
}

// Map the response into the proper response type and return it.
Expand All @@ -158,30 +157,89 @@ func (s *RPCServer) ListAccounts(context.Context,
}, nil
}

// AccountInfo returns the account with the given ID or label.
func (s *RPCServer) AccountInfo(_ context.Context,
req *litrpc.AccountInfoRequest) (*litrpc.Account, error) {

log.Infof("[accountinfo] id=%v, label=%v", req.Id, req.Label)

accountID, err := s.findAccount(req.Id, req.Label)
if err != nil {
return nil, err
}

dbAccount, err := s.service.Account(accountID)
if err != nil {
return nil, fmt.Errorf("error retrieving account: %w", err)
}

return marshalAccount(dbAccount), nil
}

// RemoveAccount removes the given account from the account database.
func (s *RPCServer) RemoveAccount(_ context.Context,
req *litrpc.RemoveAccountRequest) (*litrpc.RemoveAccountResponse,
error) {

log.Infof("[removeaccount] id=%v", req.Id)
log.Infof("[removeaccount] id=%v, label=%v", req.Id, req.Label)

// Account ID is always a hex string, convert it to our account ID type.
var accountID AccountID
decoded, err := hex.DecodeString(req.Id)
accountID, err := s.findAccount(req.Id, req.Label)
if err != nil {
return nil, fmt.Errorf("error decoding account ID: %v", err)
return nil, err
}
copy(accountID[:], decoded)

// Now remove the account.
err = s.service.RemoveAccount(accountID)
if err != nil {
return nil, fmt.Errorf("error removing account: %v", err)
return nil, fmt.Errorf("error removing account: %w", err)
}

return &litrpc.RemoveAccountResponse{}, nil
}

// findAccount finds an account by its ID or label.
func (s *RPCServer) findAccount(id string, label string) (AccountID, error) {
switch {
case id != "" && label != "":
return AccountID{}, fmt.Errorf("either account ID or label " +
"must be specified, not both")

case id != "":
// Account ID is always a hex string, convert it to our account
// ID type.
var accountID AccountID
decoded, err := hex.DecodeString(id)
if err != nil {
return AccountID{}, fmt.Errorf("error decoding "+
"account ID: %w", err)
}
copy(accountID[:], decoded)

return accountID, nil

case label != "":
// We need to find the account by its label.
accounts, err := s.service.Accounts()
if err != nil {
return AccountID{}, fmt.Errorf("unable to list "+
"accounts: %w", err)
}

for _, acct := range accounts {
if acct.Label == label {
return acct.ID, nil
}
}

return AccountID{}, fmt.Errorf("unable to find account "+
"with label '%s'", label)

default:
return AccountID{}, fmt.Errorf("either account ID or label " +
"must be specified")
}
}

// marshalAccount converts an account into its RPC counterpart.
func marshalAccount(acct *OffChainBalanceAccount) *litrpc.Account {
rpcAccount := &litrpc.Account{
Expand All @@ -196,6 +254,7 @@ func marshalAccount(acct *OffChainBalanceAccount) *litrpc.Account {
Payments: make(
[]*litrpc.AccountPayment, 0, len(acct.Payments),
),
Label: acct.Label,
}

for hash := range acct.Invoices {
Expand Down
Loading

0 comments on commit 88bb10d

Please sign in to comment.