From afa4fa08b4200b96aace814ec855eb789d598832 Mon Sep 17 00:00:00 2001 From: teisnp Date: Sun, 21 Jan 2024 15:50:10 +0100 Subject: [PATCH] billing update including webhook --- config.compose.yaml | 8 +-- config.yaml | 6 +- go.mod | 2 +- go.sum | 2 + internal/billing/service.go | 36 +++++------ internal/entities/team.go | 12 ++-- internal/rest/controllers/routes.go | 2 +- internal/rest/controllers/teams/billing.go | 4 +- internal/rest/controllers/webhooks/routes.go | 3 + internal/rest/controllers/webhooks/stripe.go | 64 +++++++++++++++++--- internal/team/repository.go | 33 ++++++++++ internal/team/service.go | 12 ++++ 12 files changed, 137 insertions(+), 47 deletions(-) diff --git a/config.compose.yaml b/config.compose.yaml index 93c7298..a81b7eb 100644 --- a/config.compose.yaml +++ b/config.compose.yaml @@ -1,5 +1,5 @@ log: - level: debug + level: info format: text rest: @@ -47,7 +47,7 @@ team: application_url: http://localhost:5173 stripe: - publishable_key: "publishable_key" - secret_key: "secret_key" - webhook_secret: "webhook_secret" + publishable_key: pk_test_51NjhPuAAd26uMX + secret_key: "sk_test_51NjhPuAAd26uMXu2QkGAVDTZDLYGRQB2oxqWkZfH6j4XUOIg2HBOOKB5wRL25vOo0VkpfjnxXnP0aZ8NZqvqex3N00JZU6eD2H" + webhook_secret: "whsec_4665482442d698095ed89e386c0069beb54a1ba141a371edd4e8e0c5126d7568" domain: http://localhost:5173 \ No newline at end of file diff --git a/config.yaml b/config.yaml index 54cfa63..4faff3e 100644 --- a/config.yaml +++ b/config.yaml @@ -47,8 +47,8 @@ prober: concurrency: 1 stripe: - publishable_key: "publishable_key" - secret_key: "secret_key" - webhook_secret: "webhook_secret" + publishable_key: pk_test_51NjhPuAAd26uMXu2mDsC5CrJzCokmFCMDEiyZFGanTQAy2exlztxyuLDpg2TXC26LK8j9wqnACLAAwEyWS0AJ4r500U1rDn672 + secret_key: sk_test_51NjhPuAAd26uMXu2QkGAVDTZDLYGRQB2oxqWkZfH6j4XUOIg2HBOOKB5wRL25vOo0VkpfjnxXnP0aZ8NZqvqex3N00JZU6eD2H + webhook_secret: "whsec_4665482442d698095ed89e386c0069beb54a1ba141a371edd4e8e0c5126d7568" domain: http://localhost:5173 \ No newline at end of file diff --git a/go.mod b/go.mod index 6738faf..69eb251 100644 --- a/go.mod +++ b/go.mod @@ -111,7 +111,7 @@ require ( github.com/spf13/cast v1.5.0 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect github.com/spf13/pflag v1.0.5 // indirect - github.com/stripe/stripe-go/v76 v76.2.0 // indirect + github.com/stripe/stripe-go/v76 v76.13.0 // indirect github.com/subosito/gotenv v1.4.1 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasttemplate v1.2.1 // indirect diff --git a/go.sum b/go.sum index e338246..5a43ee1 100644 --- a/go.sum +++ b/go.sum @@ -487,6 +487,8 @@ github.com/stripe/stripe-go v70.15.0+incompatible h1:hNML7M1zx8RgtepEMlxyu/FpVPr github.com/stripe/stripe-go v70.15.0+incompatible/go.mod h1:A1dQZmO/QypXmsL0T8axYZkSN/uA/T/A64pfKdBAMiY= github.com/stripe/stripe-go/v76 v76.2.0 h1:5zhef624MgfewEJ3YmZUfat1SGw+mtkR5HXtMW1J5dg= github.com/stripe/stripe-go/v76 v76.2.0/go.mod h1:rw1MxjlAKKcZ+3FOXgTHgwiOa2ya6CPq6ykpJ0Q6Po4= +github.com/stripe/stripe-go/v76 v76.13.0 h1:j1tkBBA2v67yKHg9hj/0c24af8hze1vVVErJC9naT9Q= +github.com/stripe/stripe-go/v76 v76.13.0/go.mod h1:rw1MxjlAKKcZ+3FOXgTHgwiOa2ya6CPq6ykpJ0Q6Po4= github.com/subosito/gotenv v1.4.1 h1:jyEFiXpy21Wm81FBN71l9VoMMV8H8jG+qIK3GCpY6Qs= github.com/subosito/gotenv v1.4.1/go.mod h1:ayKnFf/c6rvx/2iiLrJUk1e6plDbT3edrFNGqEflhK0= github.com/tj/assert v0.0.3 h1:Df/BlaZ20mq6kuai7f5z2TvPFiwC3xaWJSDQNiIS3Rk= diff --git a/internal/billing/service.go b/internal/billing/service.go index d3a674f..499c623 100644 --- a/internal/billing/service.go +++ b/internal/billing/service.go @@ -4,6 +4,7 @@ import ( "os" "strconv" + "github.com/opsway-io/backend/internal/entities" "github.com/pkg/errors" "github.com/stripe/stripe-go/v76" portalsession "github.com/stripe/stripe-go/v76/billingportal/session" @@ -20,7 +21,7 @@ type Config struct { type Service interface { PostConfig() StripeConfig - CreateCheckoutSession(teamID uint, priceId string) (*stripe.CheckoutSession, error) + CreateCheckoutSession(team *entities.Team, priceLookupKey string) (*stripe.CheckoutSession, error) GetCheckoutSession(sessionID string) (*stripe.CheckoutSession, error) CreateCustomerPortal(sessionID string) (*stripe.BillingPortalSession, error) ConstructEvent(payload []byte, header string) (stripe.Event, error) @@ -51,39 +52,32 @@ func (s *ServiceImpl) PostConfig() StripeConfig { } } -func (s *ServiceImpl) CreateCheckoutSession(teamID uint, lookupKey string) (*stripe.CheckoutSession, error) { - // priceParams := &stripe.PriceListParams{ - // LookupKeys: stripe.StringSlice([]string{ - // lookupKey, - // }), - // } - - // i := price.List(priceParams) - // var price *stripe.Price - // for i.Next() { - // p := i.Price() - // price = p - // } - +func (s *ServiceImpl) CreateCheckoutSession(team *entities.Team, priceLookupKey string) (*stripe.CheckoutSession, error) { params := &stripe.CheckoutSessionParams{ SuccessURL: stripe.String("https://my.opsway.io/team/plan"), // ReturnURL: stripe.String("https://my.opsway.io/team/plan"), - CancelURL: stripe.String(s.Config.Domain + "/canceled.html"), - Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)), - ClientReferenceID: stripe.String(strconv.FormatUint(uint64(teamID), 10)), + CancelURL: stripe.String(s.Config.Domain + "/canceled.html"), + Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)), + Customer: stripe.String(*team.StripeCustomerID), + LineItems: []*stripe.CheckoutSessionLineItemParams{ { - Price: stripe.String(lookupKey), + Price: stripe.String(priceLookupKey), Quantity: stripe.Int64(1), }, }, } + // Set teamID on session if not a previous customer + if team.StripeCustomerID == nil { + params.ClientReferenceID = stripe.String(strconv.FormatUint(uint64(team.ID), 10)) + } + return session.New(params) } func (s *ServiceImpl) GetCheckoutSession(sessionID string) (*stripe.CheckoutSession, error) { - return session.Get(sessionID, nil) + return session.Get(sessionID, &stripe.CheckoutSessionParams{}) } func (s *ServiceImpl) CreateCustomerPortal(sessionID string) (*stripe.BillingPortalSession, error) { @@ -102,5 +96,5 @@ func (s *ServiceImpl) CreateCustomerPortal(sessionID string) (*stripe.BillingPor } func (s *ServiceImpl) ConstructEvent(payload []byte, header string) (stripe.Event, error) { - return webhook.ConstructEvent(payload, header, s.Config.WebhookSecret) + return webhook.ConstructEventWithOptions(payload, header, s.Config.WebhookSecret, webhook.ConstructEventOptions{IgnoreAPIVersionMismatch: true}) } diff --git a/internal/entities/team.go b/internal/entities/team.go index 72d4c83..3c56eb2 100644 --- a/internal/entities/team.go +++ b/internal/entities/team.go @@ -29,12 +29,12 @@ var ( ) type Team struct { - ID uint - Name string `gorm:"uniqueIndex;not null"` - DisplayName *string `gorm:"index"` - PaymentPlan string `gorm:"default:FREE"` - StripeKey *string - HasAvatar bool + ID uint + Name string `gorm:"uniqueIndex;not null"` + DisplayName *string `gorm:"index"` + PaymentPlan string `gorm:"default:FREE"` + StripeCustomerID *string `gorm:"index"` + HasAvatar bool Users []User `gorm:"many2many:team_users"` Monitors []Monitor `gorm:"constraint:OnDelete:CASCADE"` diff --git a/internal/rest/controllers/routes.go b/internal/rest/controllers/routes.go index ea65673..baec650 100644 --- a/internal/rest/controllers/routes.go +++ b/internal/rest/controllers/routes.go @@ -48,7 +48,7 @@ func Register( // Webhooks - webhooks.Register(root, logger, billingService) + webhooks.Register(root, logger, billingService, teamService) // Healthz diff --git a/internal/rest/controllers/teams/billing.go b/internal/rest/controllers/teams/billing.go index 3e768b0..c7e136c 100644 --- a/internal/rest/controllers/teams/billing.go +++ b/internal/rest/controllers/teams/billing.go @@ -25,7 +25,9 @@ func (h *Handlers) PostCreateCheckoutSession(c hs.AuthenticatedContext) error { return echo.ErrBadRequest } c.Log.Info(req.TeamID) - s, err := h.BillingService.CreateCheckoutSession(req.TeamID, req.PriceLookupKey) + + team, err := h.TeamService.GetByID(c.Request().Context(), req.TeamID) + s, err := h.BillingService.CreateCheckoutSession(team, req.PriceLookupKey) if err != nil { c.Log.WithError(err).Debug("create stripe checkout session") diff --git a/internal/rest/controllers/webhooks/routes.go b/internal/rest/controllers/webhooks/routes.go index 79d7b66..e56f156 100644 --- a/internal/rest/controllers/webhooks/routes.go +++ b/internal/rest/controllers/webhooks/routes.go @@ -6,18 +6,21 @@ import ( "github.com/opsway-io/backend/internal/billing" "github.com/opsway-io/backend/internal/rest/handlers" "github.com/opsway-io/backend/internal/rest/middleware" + "github.com/opsway-io/backend/internal/team" "github.com/sirupsen/logrus" ) type Handlers struct { AuthenticationService authentication.Service BillingService billing.Service + TeamService team.Service } func Register( e *echo.Group, logger *logrus.Entry, billingService billing.Service, + teamService team.Service, ) { h := &Handlers{ BillingService: billingService, diff --git a/internal/rest/controllers/webhooks/stripe.go b/internal/rest/controllers/webhooks/stripe.go index 0333748..f832bf7 100644 --- a/internal/rest/controllers/webhooks/stripe.go +++ b/internal/rest/controllers/webhooks/stripe.go @@ -1,13 +1,28 @@ package webhooks import ( + "context" + "encoding/json" + "fmt" "io" + "net/http" + + "strconv" "github.com/labstack/echo/v4" hs "github.com/opsway-io/backend/internal/rest/handlers" + "github.com/opsway-io/backend/internal/team" + "github.com/stripe/stripe-go/v76" ) +func (h *Handlers) FulfillOrder(context context.Context, lineItems *stripe.LineItemList) { + fmt.Println(lineItems) + + // TODO: fill me in +} + func (h *Handlers) handleWebhook(c hs.StripeContext) error { + c.Log.Info("stripe webhook received") b, err := io.ReadAll(c.Request().Body) if err != nil { c.Log.WithError(err).Debug("failed to read request body for stripe event") @@ -17,26 +32,55 @@ func (h *Handlers) handleWebhook(c hs.StripeContext) error { event, err := h.BillingService.ConstructEvent(b, c.Signature) if err != nil { - c.Log.WithError(err).Debug("failed to construct stripe event") + c.Log.WithError(err).Info("failed to construct stripe event") + c.Log.Info("construct") return echo.ErrBadRequest } switch event.Type { case "checkout.session.completed": + var session stripe.CheckoutSession + err := json.Unmarshal(event.Data.Raw, &session) + if err != nil { + c.Log.WithError(err).Debug("Error parsing webhook JSON") + return echo.ErrBadRequest + } + + params := &stripe.CheckoutSessionParams{} + params.AddExpand("line_items") + + // Retrieve the session. If you require line items in the response, you may include them by expanding line_items. + sessionWithLineItems, err := h.BillingService.GetCheckoutSession(session.ID) + if err != nil { + c.Log.WithError(err).Debug("Error getting checkout session") + return echo.ErrBadRequest + } + + c.Log.Info(session.ClientReferenceID) + lineItems := sessionWithLineItems.LineItems + // Fulfill the purchase... + customerTeam, err := h.TeamService.GetByStripeID(c.Request().Context(), session.Customer.ID) + if err != nil { + if err != team.ErrNotFound { + return c.NoContent(http.StatusInternalServerError) + } + teamID, _ := strconv.ParseUint(session.ClientReferenceID, 10, 32) + customerTeam, _ := h.TeamService.GetByID(c.Request().Context(), uint(teamID)) + + h.TeamService.UpdateBilling(c.Request().Context(), customerTeam.ID, session.Customer.ID, lineItems.Data[0].Price.Product.Name) + return c.NoContent(http.StatusOK) + } + + h.TeamService.UpdateBilling(c.Request().Context(), customerTeam.ID, session.Customer.ID, lineItems.Data[0].Price.Product.Name) + + // h.FulfillOrder(lineItems) // Payment is successful and the subscription is created. // You should provision the subscription and save the customer ID to your database. - case "invoice.paid": - // Continue to provision the subscription as payments continue to be made. - // Store the status in your database and check when a user accesses your service. - // This approach helps you avoid hitting rate limits. - case "invoice.payment_failed": - // The payment failed or the customer does not have a valid payment method. - // The subscription becomes past_due. Notify your customer and send them to the - // customer portal to update their payment information. default: + c.Log.WithField("event", event.Type).Debug("Unhandled event type") // unhandled event type } - return nil + return c.NoContent(http.StatusOK) } diff --git a/internal/team/repository.go b/internal/team/repository.go index 45ed5d3..597f643 100644 --- a/internal/team/repository.go +++ b/internal/team/repository.go @@ -19,6 +19,7 @@ var ( type Repository interface { GetByID(ctx context.Context, teamId uint) (*entities.Team, error) + GetByStripeID(ctx context.Context, stripeID string) (*entities.Team, error) GetUsersByID(ctx context.Context, teamId uint, offset *int, limit *int, query *string) (*[]TeamUser, error) GetUserRole(ctx context.Context, teamID, userID uint) (*entities.TeamRole, error) GetTeamsAndRoleByUserID(ctx context.Context, userID uint) (*[]TeamAndRole, error) @@ -27,6 +28,8 @@ type Repository interface { UpdateUserRole(ctx context.Context, teamID, userID uint, role entities.TeamRole) error UpdateDisplayName(ctx context.Context, teamID uint, displayName string) error + UpdateBilling(ctx context.Context, teamID uint, customerID string, plan string) error + CreateWithOwnerUserID(ctx context.Context, team *entities.Team, ownerUserID uint) error Delete(ctx context.Context, id uint) error @@ -59,6 +62,19 @@ func (s *RepositoryImpl) GetByID(ctx context.Context, id uint) (*entities.Team, return &team, nil } +func (s *RepositoryImpl) GetByStripeID(ctx context.Context, stripeID string) (*entities.Team, error) { + var team entities.Team + if err := s.db.WithContext(ctx).Where("stripe_customer_id = ?", stripeID).First(&team).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrNotFound + } + + return nil, err + } + + return &team, nil +} + type TeamUser struct { entities.User Role entities.TeamRole @@ -117,6 +133,23 @@ func (s *RepositoryImpl) UpdateDisplayName(ctx context.Context, teamID uint, dis return nil } +func (s *RepositoryImpl) UpdateBilling(ctx context.Context, teamID uint, customerID string, plan string) error { + result := s.db.WithContext(ctx).Model(&entities.Team{}).Where(entities.Team{ + ID: teamID, + }).Updates(entities.Team{ + StripeCustomerID: &customerID, + PaymentPlan: plan, + }) + if result.Error != nil { + return result.Error + } + if result.RowsAffected == 0 { + return ErrNotFound + } + + return nil +} + func (s *RepositoryImpl) UpdateUserRole(ctx context.Context, teamID, userID uint, role entities.TeamRole) error { err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { // If role is owner, remove all other owners diff --git a/internal/team/service.go b/internal/team/service.go index b02e127..cb293db 100644 --- a/internal/team/service.go +++ b/internal/team/service.go @@ -31,11 +31,15 @@ type Service interface { GetByID(ctx context.Context, teamId uint) (*entities.Team, error) + GetByStripeID(ctx context.Context, stripeID string) (*entities.Team, error) + GetTeamsAndRoleByUserID(ctx context.Context, userID uint) (*[]TeamAndRole, error) GetUsersByID(ctx context.Context, teamId uint, offset *int, limit *int, query *string) (*[]TeamUser, error) GetUserRole(ctx context.Context, teamID, userID uint) (*entities.TeamRole, error) UpdateUserRole(ctx context.Context, teamID, userID uint, role entities.TeamRole) error + UpdateBilling(ctx context.Context, teamID uint, customerID string, plan string) error + RemoveUser(ctx context.Context, teamID, userID uint) error UpdateDisplayName(ctx context.Context, teamID uint, displayName string) error @@ -71,6 +75,10 @@ func (s *ServiceImpl) GetByID(ctx context.Context, id uint) (*entities.Team, err return s.repository.GetByID(ctx, id) } +func (s *ServiceImpl) GetByStripeID(ctx context.Context, stripeID string) (*entities.Team, error) { + return s.repository.GetByStripeID(ctx, stripeID) +} + func (s *ServiceImpl) CreateWithOwnerUserID(ctx context.Context, team *entities.Team, ownerUserID uint) error { return s.repository.CreateWithOwnerUserID(ctx, team, ownerUserID) } @@ -147,6 +155,10 @@ func (s *ServiceImpl) UpdateUserRole(ctx context.Context, teamID, userID uint, r return s.repository.UpdateUserRole(ctx, teamID, userID, role) } +func (s *ServiceImpl) UpdateBilling(ctx context.Context, teamID uint, customerID string, plan string) error { + return s.repository.UpdateBilling(ctx, teamID, customerID, plan) +} + func (s *ServiceImpl) GetTeamsAndRoleByUserID(ctx context.Context, userID uint) (*[]TeamAndRole, error) { return s.repository.GetTeamsAndRoleByUserID(ctx, userID) }