diff --git a/api/server/handlers/billing/customer.go b/api/server/handlers/billing/customer.go index 1f3aa4d79a..8acd7702dc 100644 --- a/api/server/handlers/billing/customer.go +++ b/api/server/handlers/billing/customer.go @@ -22,11 +22,10 @@ type CreateBillingCustomerHandler struct { // NewCreateBillingCustomerIfNotExists will create a new CreateBillingCustomerIfNotExists func NewCreateBillingCustomerIfNotExists( config *config.Config, - decoderValidator shared.RequestDecoderValidator, writer shared.ResultWriter, ) *CreateBillingCustomerHandler { return &CreateBillingCustomerHandler{ - PorterHandlerReadWriter: handlers.NewDefaultPorterHandler(config, decoderValidator, writer), + PorterHandlerReadWriter: handlers.NewDefaultPorterHandler(config, nil, writer), } } @@ -35,23 +34,15 @@ func (c *CreateBillingCustomerHandler) ServeHTTP(w http.ResponseWriter, r *http. defer span.End() proj, _ := ctx.Value(types.ProjectScope).(*models.Project) + user, _ := r.Context().Value(types.UserScope).(*models.User) - request := &types.CreateBillingCustomerRequest{} - if ok := c.DecodeAndValidate(w, r, request); !ok { - return - } - - // There is no easy way to pass environment variables to the frontend, - // so for now pass via the backend. This is acceptable because the key is - // meant to be public - publishableKey := c.Config().BillingManager.GetPublishableKey(ctx) if proj.BillingID != "" { - c.WriteResult(w, r, publishableKey) + c.WriteResult(w, r, "") return } // Create customer in Stripe - customerID, err := c.Config().BillingManager.CreateCustomer(ctx, request.UserEmail, proj) + customerID, err := c.Config().BillingManager.CreateCustomer(ctx, user.Email, proj) if err != nil { err := telemetry.Error(ctx, span, err, "error creating billing customer") c.HandleAPIError(w, r, apierrors.NewErrInternal(fmt.Errorf("error creating billing customer: %w", err))) @@ -61,6 +52,7 @@ func (c *CreateBillingCustomerHandler) ServeHTTP(w http.ResponseWriter, r *http. telemetry.WithAttributes(span, telemetry.AttributeKV{Key: "project-id", Value: proj.ID}, telemetry.AttributeKV{Key: "customer-id", Value: proj.BillingID}, + telemetry.AttributeKV{Key: "user-email", Value: user.Email}, ) // Update the project record with the customer ID @@ -72,5 +64,40 @@ func (c *CreateBillingCustomerHandler) ServeHTTP(w http.ResponseWriter, r *http. return } + c.WriteResult(w, r, "") +} + +// GetPublishableKeyHandler will return the configured publishable key +type GetPublishableKeyHandler struct { + handlers.PorterHandlerReadWriter +} + +// NewGetPublishableKeyHandler will return the publishable key +func NewGetPublishableKeyHandler( + config *config.Config, + decoderValidator shared.RequestDecoderValidator, + writer shared.ResultWriter, +) *GetPublishableKeyHandler { + return &GetPublishableKeyHandler{ + PorterHandlerReadWriter: handlers.NewDefaultPorterHandler(config, decoderValidator, writer), + } +} + +func (c *GetPublishableKeyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx, span := telemetry.NewSpan(r.Context(), "get-publishable-key-endpoint") + defer span.End() + + proj, _ := ctx.Value(types.ProjectScope).(*models.Project) + + // There is no easy way to pass environment variables to the frontend, + // so for now pass via the backend. This is acceptable because the key is + // meant to be public + publishableKey := c.Config().BillingManager.GetPublishableKey(ctx) + + telemetry.WithAttributes(span, + telemetry.AttributeKV{Key: "project-id", Value: proj.ID}, + telemetry.AttributeKV{Key: "customer-id", Value: proj.BillingID}, + ) + c.WriteResult(w, r, publishableKey) } diff --git a/api/server/handlers/project/create.go b/api/server/handlers/project/create.go index e670fa4b08..0d21777cbf 100644 --- a/api/server/handlers/project/create.go +++ b/api/server/handlers/project/create.go @@ -64,6 +64,12 @@ func (p *ProjectCreateHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) return } proj.BillingID = billingID + + telemetry.WithAttributes(span, + telemetry.AttributeKV{Key: "project-id", Value: proj.ID}, + telemetry.AttributeKV{Key: "customer-id", Value: proj.BillingID}, + telemetry.AttributeKV{Key: "user-email", Value: user.Email}, + ) } proj, _, err = CreateProjectWithUser(p.Repo().Project(), proj, user) diff --git a/api/server/router/project.go b/api/server/router/project.go index bee305f4d5..1daabed088 100644 --- a/api/server/router/project.go +++ b/api/server/router/project.go @@ -441,7 +441,6 @@ func getProjectRoutes( getOrCreateBillingCustomerHandler := billing.NewCreateBillingCustomerIfNotExists( config, - factory.GetDecoderValidator(), factory.GetResultWriter(), ) @@ -451,6 +450,33 @@ func getProjectRoutes( Router: r, }) + // GET /api/projects/{project_id}/billing/publishable_key -> project.NewGetPublishableKeyHandler + publishableKeyEndpoint := factory.NewAPIEndpoint( + &types.APIRequestMetadata{ + Verb: types.APIVerbGet, + Method: types.HTTPVerbGet, + Path: &types.Path{ + Parent: basePath, + RelativePath: relPath + "/billing/publishable_key", + }, + Scopes: []types.PermissionScope{ + types.ProjectScope, + }, + }, + ) + + publishableKeyHandler := billing.NewGetPublishableKeyHandler( + config, + factory.GetDecoderValidator(), + factory.GetResultWriter(), + ) + + routes = append(routes, &router.Route{ + Endpoint: publishableKeyEndpoint, + Handler: publishableKeyHandler, + Router: r, + }) + // GET /api/projects/{project_id}/clusters -> cluster.NewClusterListHandler listClusterEndpoint := factory.NewAPIEndpoint( &types.APIRequestMetadata{ diff --git a/api/types/billing.go b/api/types/billing.go index d6f032acfe..dbb6b47af4 100644 --- a/api/types/billing.go +++ b/api/types/billing.go @@ -1,10 +1,5 @@ package types -// CreateBillingCustomerRequest is a request for creating a new billing customer. -type CreateBillingCustomerRequest struct { - UserEmail string `json:"user_email" form:"required"` -} - // PaymentMethod is a subset of the Stripe PaymentMethod type, // with only the fields used in the dashboard type PaymentMethod = struct { diff --git a/dashboard/src/lib/hooks/useStripe.tsx b/dashboard/src/lib/hooks/useStripe.tsx index 0300ecb830..2f5d169a4a 100644 --- a/dashboard/src/lib/hooks/useStripe.tsx +++ b/dashboard/src/lib/hooks/useStripe.tsx @@ -31,7 +31,7 @@ type TCheckHasPaymentEnabled = { refetchPaymentEnabled: any; }; -type TCheckCustomerExists = { +type TGetPublishableKey = { publishableKey: string; }; @@ -150,20 +150,43 @@ export const checkIfProjectHasPayment = (): TCheckHasPaymentEnabled => { }; }; -export const checkBillingCustomerExists = (): TCheckCustomerExists => { +export const checkBillingCustomerExists = () => { + const { currentProject } = useContext(Context); + + useQuery(["checkCustomerExists", currentProject?.id], async () => { + if (!currentProject?.id || currentProject.id === -1) { + return; + } + + if (!currentProject?.billing_enabled) { + return; + } + + const res = await api.checkBillingCustomerExists( + "", + {}, + { project_id: currentProject?.id } + ); + return res.data; + }); +}; + +export const usePublishableKey = (): TGetPublishableKey => { const { user, currentProject } = useContext(Context); // Fetch list of payment methods const keyReq = useQuery( - ["checkCustomerExists", currentProject?.id], + ["getPublishableKey", currentProject?.id], async () => { if (!currentProject?.id || currentProject.id === -1) { return; } - const res = await api.checkBillingCustomerExists( + const res = await api.getPublishableKey( "", - { user_email: user?.email }, - { project_id: currentProject?.id } + {}, + { + project_id: currentProject?.id, + } ); return res.data; } diff --git a/dashboard/src/main/home/Home.tsx b/dashboard/src/main/home/Home.tsx index 49b56cb6bc..28cd412311 100644 --- a/dashboard/src/main/home/Home.tsx +++ b/dashboard/src/main/home/Home.tsx @@ -17,6 +17,7 @@ import Modal from "components/porter/Modal"; import ShowIntercomButton from "components/porter/ShowIntercomButton"; import Spacer from "components/porter/Spacer"; import Text from "components/porter/Text"; +import { checkBillingCustomerExists } from "lib/hooks/useStripe"; import api from "shared/api"; import { withAuth, type WithAuthProps } from "shared/auth/AuthorizationHoc"; @@ -293,6 +294,9 @@ const Home: React.FC = (props) => { prevCurrentCluster.current = props.currentCluster; }, [props.currentCluster]); + // Create Stripe customer if it doesn't exists already + checkBillingCustomerExists(); + const projectOverlayCall = async () => { try { const projectList = await api diff --git a/dashboard/src/main/home/modals/BillingModal.tsx b/dashboard/src/main/home/modals/BillingModal.tsx index 44e933f285..51823c2715 100644 --- a/dashboard/src/main/home/modals/BillingModal.tsx +++ b/dashboard/src/main/home/modals/BillingModal.tsx @@ -6,7 +6,7 @@ import Link from "components/porter/Link"; import Modal from "components/porter/Modal"; import Spacer from "components/porter/Spacer"; import Text from "components/porter/Text"; -import { checkBillingCustomerExists } from "lib/hooks/useStripe"; +import { usePublishableKey } from "lib/hooks/useStripe"; import PaymentSetupForm from "./PaymentSetupForm"; @@ -17,7 +17,7 @@ const BillingModal = ({ back: (value: React.SetStateAction) => void; onCreate: () => Promise; }) => { - const { publishableKey } = checkBillingCustomerExists(); + const { publishableKey } = usePublishableKey(); const stripePromise = loadStripe(publishableKey); const appearance = { diff --git a/dashboard/src/main/home/project-settings/BillingPage.tsx b/dashboard/src/main/home/project-settings/BillingPage.tsx index 1525c7b256..f8ce724026 100644 --- a/dashboard/src/main/home/project-settings/BillingPage.tsx +++ b/dashboard/src/main/home/project-settings/BillingPage.tsx @@ -10,7 +10,6 @@ import Image from "components/porter/Image"; import Spacer from "components/porter/Spacer"; import Text from "components/porter/Text"; import { - checkBillingCustomerExists, checkIfProjectHasPayment, usePaymentMethods, useSetDefaultPaymentMethod, @@ -34,7 +33,6 @@ function BillingPage(): JSX.Element { deletingIds, } = usePaymentMethods(); const { setDefaultPaymentMethod } = useSetDefaultPaymentMethod(); - checkBillingCustomerExists(); const { refetchPaymentEnabled } = checkIfProjectHasPayment(); diff --git a/dashboard/src/shared/api.tsx b/dashboard/src/shared/api.tsx index 8e5e52cb0b..2faba61e7d 100644 --- a/dashboard/src/shared/api.tsx +++ b/dashboard/src/shared/api.tsx @@ -3442,14 +3442,22 @@ const removeStackEnvGroup = baseApi< // Billing const checkBillingCustomerExists = baseApi< - { - user_email?: string; - }, + {}, { project_id?: number; } >("POST", ({ project_id }) => `/api/projects/${project_id}/billing/customer`); +const getPublishableKey = baseApi< + {}, + { + project_id?: number; + } +>( + "GET", + ({ project_id }) => `/api/projects/${project_id}/billing/publishable_key` +); + const getHasBilling = baseApi<{}, { project_id: number }>( "GET", ({ project_id }) => `/api/projects/${project_id}/billing` @@ -3847,6 +3855,7 @@ export default { // BILLING checkBillingCustomerExists, + getPublishableKey, listPaymentMethod, addPaymentMethod, setDefaultPaymentMethod, diff --git a/internal/billing/stripe.go b/internal/billing/stripe.go index 3e5d3a39bd..b667a3fcef 100644 --- a/internal/billing/stripe.go +++ b/internal/billing/stripe.go @@ -3,6 +3,7 @@ package billing import ( "context" "fmt" + "strconv" "github.com/porter-dev/porter/api/types" "github.com/porter-dev/porter/internal/models" @@ -34,9 +35,13 @@ func (s *StripeBillingManager) CreateCustomer(ctx context.Context, userEmail str if proj.BillingID == "" { // Create customer if not exists customerName := fmt.Sprintf("project_%s", proj.Name) + projectIDStr := strconv.FormatUint(uint64(proj.ID), 10) params := &stripe.CustomerParams{ Name: stripe.String(customerName), Email: stripe.String(userEmail), + Metadata: map[string]string{ + "porter_project_id": projectIDStr, + }, } // Create in Stripe