Skip to content

Commit

Permalink
sqlite: set siafund address, fix siafund balance update
Browse files Browse the repository at this point in the history
  • Loading branch information
n8maninger committed Feb 28, 2024
1 parent e2d6437 commit b9279f9
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 33 deletions.
34 changes: 19 additions & 15 deletions persist/sqlite/consensus.go
Original file line number Diff line number Diff line change
Expand Up @@ -394,13 +394,13 @@ func (ut *updateTx) AddSiafundElements(elements []types.SiafundElement, index ty
}
defer insertStmt.Close()

balanceChanges := make(map[types.Address]uint64)
balanceChanges := make(map[int64]uint64)
for _, se := range elements {
addrRef, err := scanAddress(addrStmt.QueryRow(encode(se.SiafundOutput.Address), encode(types.ZeroCurrency), 0))
if err != nil {
return fmt.Errorf("failed to query address: %w", err)
} else if _, ok := balanceChanges[se.SiafundOutput.Address]; !ok {
balanceChanges[se.SiafundOutput.Address] = addrRef.Balance.Siafunds
} else if _, ok := balanceChanges[addrRef.ID]; !ok {
balanceChanges[addrRef.ID] = addrRef.Balance.Siafunds
}

var dummy types.Hash256
Expand All @@ -410,21 +410,21 @@ func (ut *updateTx) AddSiafundElements(elements []types.SiafundElement, index ty
} else if err != nil {
return fmt.Errorf("failed to execute statement: %w", err)
}
balanceChanges[se.SiafundOutput.Address] += se.SiafundOutput.Value
balanceChanges[addrRef.ID] += se.SiafundOutput.Value
}

if len(balanceChanges) == 0 {
return nil
}

updateAddressBalanceStmt, err := ut.tx.Prepare(`UPDATE sia_addresses SET siafund_balance=$1 WHERE sia_address=$2`)
updateAddressBalanceStmt, err := ut.tx.Prepare(`UPDATE sia_addresses SET siafund_balance=$1 WHERE id=$2`)
if err != nil {
return fmt.Errorf("failed to prepare update balance statement: %w", err)
}
defer updateAddressBalanceStmt.Close()

for addr, balance := range balanceChanges {
res, err := updateAddressBalanceStmt.Exec(balance, encode(addr))
for addrID, balance := range balanceChanges {
res, err := updateAddressBalanceStmt.Exec(balance, addrID)
if err != nil {
return fmt.Errorf("failed to update balance: %w", err)
} else if n, err := res.RowsAffected(); err != nil {
Expand All @@ -449,13 +449,13 @@ func (ut *updateTx) RemoveSiafundElements(elements []types.SiafundElement, index
}
defer stmt.Close()

balanceChanges := make(map[types.Address]uint64)
balanceChanges := make(map[int64]uint64)
for _, se := range elements {
addrRef, err := scanAddress(addrStmt.QueryRow(encode(se.SiafundOutput.Address), encode(types.ZeroCurrency), 0))
if err != nil {
return fmt.Errorf("failed to query address: %w", err)
} else if _, ok := balanceChanges[se.SiafundOutput.Address]; !ok {
balanceChanges[se.SiafundOutput.Address] = addrRef.Balance.Siafunds
} else if _, ok := balanceChanges[addrRef.ID]; !ok {
balanceChanges[addrRef.ID] = addrRef.Balance.Siafunds
}

var dummy types.Hash256
Expand All @@ -464,20 +464,24 @@ func (ut *updateTx) RemoveSiafundElements(elements []types.SiafundElement, index
return fmt.Errorf("failed to delete element %q: %w", se.ID, err)
}

if balanceChanges[se.SiafundOutput.Address] < se.SiafundOutput.Value {
if balanceChanges[addrRef.ID] < se.SiafundOutput.Value {
panic("siafund balance cannot be negative")
}
balanceChanges[se.SiafundOutput.Address] -= se.SiafundOutput.Value
balanceChanges[addrRef.ID] -= se.SiafundOutput.Value
}

if len(balanceChanges) == 0 {
return nil
}

updateAddressBalanceStmt, err := ut.tx.Prepare(`UPDATE sia_addresses SET siafund_balance=$1 WHERE sia_address=$2`)
updateAddressBalanceStmt, err := ut.tx.Prepare(`UPDATE sia_addresses SET siafund_balance=$1 WHERE id=$2`)
if err != nil {
return fmt.Errorf("failed to prepare update balance statement: %w", err)
}
defer updateAddressBalanceStmt.Close()

for addr, balance := range balanceChanges {
res, err := updateAddressBalanceStmt.Exec(balance, encode(addr))
for addrID, balance := range balanceChanges {
res, err := updateAddressBalanceStmt.Exec(balance, addrID)
if err != nil {
return fmt.Errorf("failed to update balance: %w", err)
} else if n, err := res.RowsAffected(); err != nil {
Expand Down
30 changes: 16 additions & 14 deletions persist/sqlite/consensus_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ import (
"go.uber.org/zap/zaptest"
)

func testV1Network() (*consensus.Network, types.Block) {
func testV1Network(siafundAddr types.Address) (*consensus.Network, types.Block) {
// use a modified version of Zen
n, genesisBlock := chain.TestnetZen()
genesisBlock.Transactions[0].SiafundOutputs[0].Address = siafundAddr
n.InitialTarget = types.BlockID{0xFF}
n.HardforkDevAddr.Height = 1
n.HardforkTax.Height = 1
Expand All @@ -28,9 +29,10 @@ func testV1Network() (*consensus.Network, types.Block) {
return n, genesisBlock
}

func testV2Network() (*consensus.Network, types.Block) {
func testV2Network(siafundAddr types.Address) (*consensus.Network, types.Block) {
// use a modified version of Zen
n, genesisBlock := chain.TestnetZen()
genesisBlock.Transactions[0].SiafundOutputs[0].Address = siafundAddr
n.InitialTarget = types.BlockID{0xFF}
n.HardforkDevAddr.Height = 1
n.HardforkTax.Height = 1
Expand Down Expand Up @@ -75,6 +77,9 @@ func mineV2Block(state consensus.State, txns []types.V2Transaction, minerAddr ty
}

func TestReorg(t *testing.T) {
pk := types.GeneratePrivateKey()
addr := types.StandardUnlockHash(pk.PublicKey())

log := zaptest.NewLogger(t)
dir := t.TempDir()
db, err := sqlite.OpenDatabase(filepath.Join(dir, "walletd.sqlite3"), log.Named("sqlite3"))
Expand All @@ -89,7 +94,7 @@ func TestReorg(t *testing.T) {
}
defer bdb.Close()

network, genesisBlock := testV1Network()
network, genesisBlock := testV1Network(addr)

store, genesisState, err := chain.NewDBStore(bdb, network, genesisBlock)
if err != nil {
Expand All @@ -103,9 +108,6 @@ func TestReorg(t *testing.T) {
t.Fatal(err)
}

pk := types.GeneratePrivateKey()
addr := types.StandardUnlockHash(pk.PublicKey())

w, err := db.AddWallet(wallet.Wallet{Name: "test"})
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -280,6 +282,9 @@ func TestReorg(t *testing.T) {
}

func TestEphemeralBalance(t *testing.T) {
pk := types.GeneratePrivateKey()
addr := types.StandardUnlockHash(pk.PublicKey())

log := zaptest.NewLogger(t)
dir := t.TempDir()
db, err := sqlite.OpenDatabase(filepath.Join(dir, "walletd.sqlite3"), log.Named("sqlite3"))
Expand All @@ -294,7 +299,7 @@ func TestEphemeralBalance(t *testing.T) {
}
defer bdb.Close()

network, genesisBlock := testV1Network()
network, genesisBlock := testV1Network(addr)

store, genesisState, err := chain.NewDBStore(bdb, network, genesisBlock)
if err != nil {
Expand All @@ -308,9 +313,6 @@ func TestEphemeralBalance(t *testing.T) {
t.Fatal(err)
}

pk := types.GeneratePrivateKey()
addr := types.StandardUnlockHash(pk.PublicKey())

w, err := db.AddWallet(wallet.Wallet{Name: "test"})
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -475,6 +477,9 @@ func TestEphemeralBalance(t *testing.T) {
}

func TestV2(t *testing.T) {
pk := types.GeneratePrivateKey()
addr := types.StandardUnlockHash(pk.PublicKey())

log := zaptest.NewLogger(t)
dir := t.TempDir()
db, err := sqlite.OpenDatabase(filepath.Join(dir, "walletd.sqlite3"), log.Named("sqlite3"))
Expand All @@ -489,7 +494,7 @@ func TestV2(t *testing.T) {
}
defer bdb.Close()

network, genesisBlock := testV2Network()
network, genesisBlock := testV2Network(addr)

store, genesisState, err := chain.NewDBStore(bdb, network, genesisBlock)
if err != nil {
Expand All @@ -503,9 +508,6 @@ func TestV2(t *testing.T) {
t.Fatal(err)
}

pk := types.GeneratePrivateKey()
addr := types.StandardUnlockHash(pk.PublicKey())

w, err := db.AddWallet(wallet.Wallet{Name: "test"})
if err != nil {
t.Fatal(err)
Expand Down
8 changes: 4 additions & 4 deletions persist/sqlite/wallet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ func TestWalletAddresses(t *testing.T) {
}

func TestResubscribe(t *testing.T) {
pk := types.GeneratePrivateKey()
addr := types.StandardUnlockHash(pk.PublicKey())

log := zaptest.NewLogger(t)
dir := t.TempDir()
db, err := sqlite.OpenDatabase(filepath.Join(dir, "walletd.sqlite3"), log.Named("sqlite3"))
Expand All @@ -123,7 +126,7 @@ func TestResubscribe(t *testing.T) {
}
defer bdb.Close()

network, genesisBlock := testV1Network()
network, genesisBlock := testV1Network(types.VoidAddress)

store, genesisState, err := chain.NewDBStore(bdb, network, genesisBlock)
if err != nil {
Expand All @@ -137,9 +140,6 @@ func TestResubscribe(t *testing.T) {
t.Fatal(err)
}

pk := types.GeneratePrivateKey()
addr := types.StandardUnlockHash(pk.PublicKey())

w, err := db.AddWallet(wallet.Wallet{Name: "test"})
if err != nil {
t.Fatal(err)
Expand Down

0 comments on commit b9279f9

Please sign in to comment.