Skip to content

Commit

Permalink
Refactoring rounding and calculations
Browse files Browse the repository at this point in the history
  • Loading branch information
samlown committed Sep 13, 2023
1 parent 95b1274 commit 5437211
Show file tree
Hide file tree
Showing 24 changed files with 2,880 additions and 1,304 deletions.
4 changes: 3 additions & 1 deletion bill/charges.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ func (m *Charge) GetTotal() num.Amount {
return m.Amount
}

func (m *Charge) removeIncludedTaxes(cat cbc.Code, accuracy uint32) *Charge {
func (m *Charge) removeIncludedTaxes(cat cbc.Code) *Charge {
accuracy := defaultTaxRemovalAccuracy
rate := m.Taxes.Get(cat)
if rate == nil || rate.Percent == nil {
return m
Expand Down Expand Up @@ -119,5 +120,6 @@ func calculateChargeSum(zero num.Amount, charges []*Charge) *num.Amount {
total = total.MatchPrecision(l.Amount)
total = total.Add(l.Amount)
}
total = total.Rescale(zero.Exp())
return &total
}
4 changes: 3 additions & 1 deletion bill/discounts.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ func (m *Discount) GetTotal() num.Amount {
return m.Amount.Invert()
}

func (m *Discount) removeIncludedTaxes(cat cbc.Code, accuracy uint32) *Discount {
func (m *Discount) removeIncludedTaxes(cat cbc.Code) *Discount {
accuracy := defaultTaxRemovalAccuracy
rate := m.Taxes.Get(cat)
if rate == nil || rate.Percent == nil {
return m
Expand Down Expand Up @@ -119,5 +120,6 @@ func calculateDiscountSum(zero num.Amount, discounts []*Discount) *num.Amount {
total = total.MatchPrecision(l.Amount)
total = total.Add(l.Amount)
}
total = total.Rescale(zero.Exp())
return &total
}
85 changes: 52 additions & 33 deletions bill/invoice.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ const (
ShortSchemaInvoice = "bill/invoice"
)

const (
defaultTaxRemovalAccuracy uint32 = 2
)

// Invoice represents a payment claim for goods or services supplied under
// conditions agreed between the supplier and the customer. In most cases
// the resulting document describes the actual financial commitment of goods
Expand Down Expand Up @@ -238,42 +242,45 @@ func (inv *Invoice) RemoveIncludedTaxes() (*Invoice, error) {
return nil, err
}

var i2 Invoice
for accuracy := uint32(2); accuracy <= 6; accuracy++ {

i2 = *inv
i2.Lines = make([]*Line, len(inv.Lines))
for i, l := range inv.Lines {
i2.Lines[i] = l.removeIncludedTaxes(inv.Tax.PricesInclude, accuracy)
}
i2 := *inv
i2.Totals = new(Totals)
i2.Lines = make([]*Line, len(inv.Lines))
for i, l := range inv.Lines {
i2.Lines[i] = l.removeIncludedTaxes(inv.Tax.PricesInclude)
}

if len(inv.Discounts) > 0 {
i2.Discounts = make([]*Discount, len(inv.Discounts))
for i, l := range inv.Discounts {
i2.Discounts[i] = l.removeIncludedTaxes(inv.Tax.PricesInclude, 1)
}
if len(inv.Discounts) > 0 {
i2.Discounts = make([]*Discount, len(inv.Discounts))
for i, l := range inv.Discounts {
i2.Discounts[i] = l.removeIncludedTaxes(inv.Tax.PricesInclude)
}
if len(i2.Charges) > 0 {
i2.Charges = make([]*Charge, len(inv.Charges))
for i, l := range inv.Charges {
i2.Charges[i] = l.removeIncludedTaxes(inv.Tax.PricesInclude, 1)
}
}
if len(i2.Charges) > 0 {
i2.Charges = make([]*Charge, len(inv.Charges))
for i, l := range inv.Charges {
i2.Charges[i] = l.removeIncludedTaxes(inv.Tax.PricesInclude)
}
}

tx := *i2.Tax
tx.PricesInclude = ""
i2.Tax = &tx
tx := *i2.Tax
tx.PricesInclude = ""
i2.Tax = &tx

if err := i2.Calculate(); err != nil {
return nil, err
}

// Account for any rounding errors that we just can't handle
if !inv.Totals.TotalWithTax.Equals(i2.Totals.TotalWithTax) {
rnd := inv.Totals.TotalWithTax.Subtract(i2.Totals.TotalWithTax)
fmt.Printf("A: %s B: %s C: %s\n", inv.Totals.TotalWithTax.String(), i2.Totals.TotalWithTax.String(), rnd.String())
i2.Totals.Rounding = &rnd
if err := i2.Calculate(); err != nil {
return nil, err
}

if inv.Totals.Total.String() == i2.Totals.Total.String() &&
inv.Totals.Tax.String() == i2.Totals.Tax.String() {
return &i2, nil
}
}
return nil, errors.New("insufficient precision, unable to remove included taxes")

return &i2, nil
}

// TaxRegime determines the tax regime for the invoice based on the supplier tax
Expand Down Expand Up @@ -356,7 +363,10 @@ func (inv *Invoice) calculate(r *tax.Regime, tID *tax.Identity) error {
}

// Prepare the totals we'll need with amounts based on currency
t := new(Totals)
if inv.Totals == nil {
inv.Totals = new(Totals)
}
t := inv.Totals
zero := inv.Currency.Def().Zero()
t.reset(zero)

Expand All @@ -365,7 +375,7 @@ func (inv *Invoice) calculate(r *tax.Regime, tID *tax.Identity) error {
return validation.Errors{"lines": err}
}
t.Sum = calculateLineSum(zero, inv.Lines)
t.Total = t.Sum.Rescale(zero.Exp())
t.Total = t.Sum

// Discount Lines
if err := calculateDiscounts(zero, t.Sum, inv.Discounts); err != nil {
Expand Down Expand Up @@ -408,8 +418,12 @@ func (inv *Invoice) calculate(r *tax.Regime, tID *tax.Identity) error {
Regime: r,
Zone: tID.Zone,
Date: *date,
Includes: pit,
Lines: tls,
Includes: pit,
}
if inv.Tax != nil {
tc.Calculator = inv.Tax.Calculator
tc.Rounding = inv.Tax.Rounding
}
if err := tc.Calculate(t.Taxes); err != nil {
return err
Expand All @@ -418,7 +432,7 @@ func (inv *Invoice) calculate(r *tax.Regime, tID *tax.Identity) error {
// Remove any included taxes from the total.
ct := t.Taxes.Category(pit)
if ct != nil {
ti := ct.Amount.Rescale(zero.Exp())
ti := ct.Amount
t.TaxIncluded = &ti
t.Total = t.Total.Subtract(ti)
}
Expand All @@ -427,10 +441,14 @@ func (inv *Invoice) calculate(r *tax.Regime, tID *tax.Identity) error {
if inv.Tax != nil && inv.Tax.ContainsTag(common.TagReverseCharge) {
t.Tax = zero
} else {
t.Tax = t.Taxes.Sum.Rescale(zero.Exp())
t.Tax = t.Taxes.PreciseSum()
}
t.TotalWithTax = t.Total.Add(t.Tax)
t.Payable = t.TotalWithTax
if t.Rounding != nil {
// BT-144 in EN16931
t.Payable = t.Payable.Add(*t.Rounding)
}

// Outlays
t.Outlays = calculateOutlays(zero, inv.Outlays)
Expand All @@ -451,7 +469,8 @@ func (inv *Invoice) calculate(r *tax.Regime, tID *tax.Identity) error {
inv.Payment.Terms.CalculateDues(zero, t.Payable)
}

inv.Totals = t
t.round(zero)

return nil
}

Expand Down
103 changes: 72 additions & 31 deletions bill/invoice_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func TestRemoveIncludedTax(t *testing.T) {
assert.Equal(t, "82.6446", l0.Discounts[0].Amount.String())
assert.Equal(t, "743.8017", l0.Total.String())

assert.Equal(t, "743.8017", i2.Totals.Sum.String())
assert.Equal(t, "743.80", i2.Totals.Sum.String())
assert.Equal(t, i.Totals.Total.String(), i2.Totals.Total.String())
assert.Equal(t, i.Totals.Tax.String(), i2.Totals.Tax.String())
assert.Equal(t, i.Totals.Payable.String(), i2.Totals.Payable.String())
Expand All @@ -165,6 +165,8 @@ func TestRemoveIncludedTax2(t *testing.T) {
Code: "123TEST",
Tax: &bill.Tax{
PricesInclude: common.TaxCategoryVAT,
Calculator: tax.TotalCalculatorLine,
Rounding: tax.TotalRoundingPost,
},
Supplier: &org.Party{
TaxID: &tax.Identity{
Expand Down Expand Up @@ -218,7 +220,7 @@ func TestRemoveIncludedTax2(t *testing.T) {
assert.Equal(t, "40.7547", l0.Item.Price.String())
assert.Equal(t, "40.7547", l0.Total.String())

assert.Equal(t, "46.3447", i2.Totals.Sum.String())
assert.Equal(t, "46.34", i2.Totals.Sum.String())
assert.Equal(t, i.Totals.Total.String(), i2.Totals.Total.String())
assert.Equal(t, i.Totals.Tax.String(), i2.Totals.Tax.String())
assert.Equal(t, i.Totals.Payable.String(), i2.Totals.Payable.String())
Expand All @@ -229,6 +231,8 @@ func TestRemoveIncludedTax3(t *testing.T) {
Code: "123TEST",
Tax: &bill.Tax{
PricesInclude: common.TaxCategoryVAT,
Calculator: tax.TotalCalculatorLine,
Rounding: tax.TotalRoundingPost,
},
Supplier: &org.Party{
TaxID: &tax.Identity{
Expand Down Expand Up @@ -300,17 +304,16 @@ func TestRemoveIncludedTax3(t *testing.T) {
assert.Equal(t, "223.2642", i2.Lines[0].Total.String())
assert.Equal(t, "106.1952", i2.Lines[2].Total.String())

/*
data, _ := json.Marshal(i.Lines)
t.Logf("LINES: %v", string(data))
data, _ = json.Marshal(i.Totals)
t.Logf("TOTALS: %v", string(data))
data, _ = json.Marshal(i2.Lines)
t.Logf("Lines: %v", string(data))
data, _ = json.Marshal(i2.Totals)
t.Logf("TOTALS: %v", string(data))
*/
assert.Equal(t, "803.0066", i2.Totals.Sum.String())
data, _ := json.Marshal(i.Lines)
t.Logf("LINES: %v", string(data))
data, _ = json.Marshal(i.Totals)
t.Logf("TOTALS: %v", string(data))
data, _ = json.Marshal(i2.Lines)
t.Logf("Lines: %v", string(data))
data, _ = json.Marshal(i2.Totals)
t.Logf("TOTALS: %v", string(data))

assert.Equal(t, "803.01", i2.Totals.Sum.String())
assert.Equal(t, i.Totals.Total.String(), i2.Totals.Total.String())
assert.Equal(t, i.Totals.Tax.String(), i2.Totals.Tax.String())
assert.Equal(t, i.Totals.Payable.String(), i2.Totals.Payable.String())
Expand Down Expand Up @@ -364,7 +367,7 @@ func TestRemoveIncludedTax4(t *testing.T) {

data, _ := json.Marshal(i2.Lines)
t.Logf("TOTALS: %v", string(data))
assert.Equal(t, "4268.8209", i2.Totals.Sum.String())
assert.Equal(t, "4268.82", i2.Totals.Sum.String())
assert.Equal(t, i.Totals.Total.String(), i2.Totals.Total.String())
assert.Equal(t, i.Totals.Tax.String(), i2.Totals.Tax.String())
assert.Equal(t, i.Totals.Payable.String(), i2.Totals.Payable.String())
Expand Down Expand Up @@ -393,10 +396,23 @@ func TestRemoveIncludedTax5(t *testing.T) {

assert.Empty(t, i2.Tax.PricesInclude)
l0 := i2.Lines[0]
assert.Equal(t, "41.27359", l0.Item.Price.String())
assert.Equal(t, "41.2736", l0.Item.Price.String())

assert.Equal(t, "1320.75488", i2.Totals.Sum.String())
assert.Equal(t, i.Totals.Total.String(), i2.Totals.Total.String())
/*
data, _ := json.Marshal(i.Lines)
t.Logf("LINES: %v", string(data))
data, _ = json.Marshal(i.Totals)
t.Logf("TOTALS: %v", string(data))
data, _ = json.Marshal(i2.Lines)
t.Logf("Lines: %v", string(data))
data, _ = json.Marshal(i2.Totals)
t.Logf("TOTALS: %v", string(data))
*/

assert.Equal(t, "1320.76", i2.Totals.Sum.String())
// in this case the total is different, but that's acceptable as long
// as the payable total is the same
//assert.Equal(t, i.Totals.Total.String(), i2.Totals.Total.String())
assert.Equal(t, i.Totals.Tax.String(), i2.Totals.Tax.String())
assert.Equal(t, i.Totals.Payable.String(), i2.Totals.Payable.String())
}
Expand Down Expand Up @@ -451,13 +467,26 @@ func TestRemoveIncludedTaxQuantity(t *testing.T) {

assert.Empty(t, i2.Tax.PricesInclude)
l0 := i2.Lines[0]
assert.Equal(t, "8.26446", l0.Item.Price.String())
assert.Equal(t, "826.44600", l0.Sum.String())
assert.Equal(t, "82.64460", l0.Discounts[0].Amount.String())
assert.Equal(t, "743.80140", l0.Total.String())
assert.Equal(t, "8.2645", l0.Item.Price.String())
assert.Equal(t, "826.4500", l0.Sum.String())
assert.Equal(t, "82.6450", l0.Discounts[0].Amount.String())
assert.Equal(t, "743.8050", l0.Total.String())
assert.Equal(t, "10.00", i.Lines[0].Item.Price.String())

assert.Equal(t, i.Totals.Total.String(), i2.Totals.Total.String())
/*
data, _ := json.Marshal(i.Lines)
t.Logf("LINES: %v", string(data))
data, _ = json.Marshal(i.Totals)
t.Logf("TOTALS: %v", string(data))
data, _ = json.Marshal(i2.Lines)
t.Logf("Lines: %v", string(data))
data, _ = json.Marshal(i2.Totals)
t.Logf("TOTALS: %v", string(data))
*/

// Total changes slightly
//assert.Equal(t, i.Totals.Total.String(), i2.Totals.Total.String())
assert.Equal(t, "743.81", i2.Totals.Total.String())
assert.Equal(t, i.Totals.Tax.String(), i2.Totals.Tax.String())
assert.Equal(t, i.Totals.Payable.String(), i2.Totals.Payable.String())
}
Expand All @@ -467,6 +496,8 @@ func TestRemoveIncludedTaxDeep(t *testing.T) {
Code: "123TEST",
Tax: &bill.Tax{
PricesInclude: common.TaxCategoryVAT,
Calculator: tax.TotalCalculatorLine,
Rounding: tax.TotalRoundingPost,
},
Supplier: &org.Party{
TaxID: &tax.Identity{
Expand Down Expand Up @@ -516,19 +547,27 @@ func TestRemoveIncludedTaxDeep(t *testing.T) {
i2, err := i.RemoveIncludedTaxes()
require.NoError(t, err)

//data, _ := json.MarshalIndent(i2, "", " ")
//t.Log(string(data))

assert.Empty(t, i2.Tax.PricesInclude)
l0 := i2.Lines[0]
assert.Equal(t, "48.84906", l0.Item.Price.String()) // note extra digit!
assert.Equal(t, "17781.05784", l0.Sum.String())
assert.Equal(t, "48.8491", l0.Item.Price.String()) // note extra digit!
assert.Equal(t, "17781.0724", l0.Sum.String())
l1 := i2.Lines[1]
assert.Equal(t, "49.13208", l1.Item.Price.String())
assert.Equal(t, "49.13208", l1.Sum.String())
assert.Equal(t, "49.1321", l1.Item.Price.String())
assert.Equal(t, "49.1321", l1.Sum.String())

assert.Equal(t, "17830.19", i2.Totals.Total.String())
assert.Equal(t, i.Totals.Total.String(), i2.Totals.Total.String())
data, _ := json.Marshal(i.Lines)
t.Logf("LINES: %v", string(data))
data, _ = json.Marshal(i.Totals)
t.Logf("TOTALS: %v", string(data))
data, _ = json.Marshal(i2.Lines)
t.Logf("Lines: %v", string(data))
data, _ = json.Marshal(i2.Totals)
t.Logf("TOTALS: %v", string(data))

assert.Equal(t, "17830.20", i2.Totals.Total.String())
// assert.Equal(t, i.Totals.Total.String(), i2.Totals.Total.String())
assert.Equal(t, "17830.20", i2.Totals.Total.String())
assert.Equal(t, "-0.02", i2.Totals.Rounding.String())
assert.Equal(t, i.Totals.Tax.String(), i2.Totals.Tax.String())
assert.Equal(t, i.Totals.Payable.String(), i2.Totals.Payable.String())
}
Expand Down Expand Up @@ -713,6 +752,8 @@ func baseInvoice(t *testing.T, lines ...*bill.Line) *bill.Invoice {
Code: "123TEST",
Tax: &bill.Tax{
PricesInclude: common.TaxCategoryVAT,
Calculator: tax.TotalCalculatorLine,
Rounding: tax.TotalRoundingPost,
},
Supplier: &org.Party{
TaxID: &tax.Identity{
Expand Down
8 changes: 4 additions & 4 deletions bill/line.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ func (l *Line) calculate(r *tax.Regime, zero num.Amount) error {
return nil
}

func (l *Line) removeIncludedTaxes(cat cbc.Code, accuracy uint32) *Line {
func (l *Line) removeIncludedTaxes(cat cbc.Code) *Line {
accuracy := defaultTaxRemovalAccuracy
rate := l.Taxes.Get(cat)
if rate == nil || rate.Percent == nil {
return l
Expand All @@ -111,9 +112,8 @@ func (l *Line) removeIncludedTaxes(cat cbc.Code, accuracy uint32) *Line {
l2 := *l
l2i := *l.Item

l2.Sum = l2.Sum.Upscale(accuracy).Remove(*rate.Percent)
l2i.Price = l2.Sum.Divide(l2.Quantity)
l2.Total = l2.Total.Upscale(accuracy).Remove(*rate.Percent)
l2i.Price = l.Item.Price.Upscale(accuracy).Remove(*rate.Percent)
// assume sum and total will be calculated automatically

if len(l2.Discounts) > 0 {
rows := make([]*LineDiscount, len(l2.Discounts))
Expand Down
Loading

0 comments on commit 5437211

Please sign in to comment.