Skip to content

Commit

Permalink
Merge branch 'master' into debug-tab
Browse files Browse the repository at this point in the history
  • Loading branch information
Feroze Mohideen authored Apr 1, 2024
2 parents 971b694 + 7177585 commit 867b89e
Show file tree
Hide file tree
Showing 10 changed files with 125 additions and 32 deletions.
53 changes: 40 additions & 13 deletions api/server/handlers/billing/customer.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,10 @@ type CreateBillingCustomerHandler struct {
// NewCreateBillingCustomerIfNotExists will create a new CreateBillingCustomerIfNotExists
func NewCreateBillingCustomerIfNotExists(
config *config.Config,
decoderValidator shared.RequestDecoderValidator,
writer shared.ResultWriter,
) *CreateBillingCustomerHandler {
return &CreateBillingCustomerHandler{
PorterHandlerReadWriter: handlers.NewDefaultPorterHandler(config, decoderValidator, writer),
PorterHandlerReadWriter: handlers.NewDefaultPorterHandler(config, nil, writer),
}
}

Expand All @@ -35,23 +34,15 @@ func (c *CreateBillingCustomerHandler) ServeHTTP(w http.ResponseWriter, r *http.
defer span.End()

proj, _ := ctx.Value(types.ProjectScope).(*models.Project)
user, _ := r.Context().Value(types.UserScope).(*models.User)

request := &types.CreateBillingCustomerRequest{}
if ok := c.DecodeAndValidate(w, r, request); !ok {
return
}

// There is no easy way to pass environment variables to the frontend,
// so for now pass via the backend. This is acceptable because the key is
// meant to be public
publishableKey := c.Config().BillingManager.GetPublishableKey(ctx)
if proj.BillingID != "" {
c.WriteResult(w, r, publishableKey)
c.WriteResult(w, r, "")
return
}

// Create customer in Stripe
customerID, err := c.Config().BillingManager.CreateCustomer(ctx, request.UserEmail, proj)
customerID, err := c.Config().BillingManager.CreateCustomer(ctx, user.Email, proj)
if err != nil {
err := telemetry.Error(ctx, span, err, "error creating billing customer")
c.HandleAPIError(w, r, apierrors.NewErrInternal(fmt.Errorf("error creating billing customer: %w", err)))
Expand All @@ -61,6 +52,7 @@ func (c *CreateBillingCustomerHandler) ServeHTTP(w http.ResponseWriter, r *http.
telemetry.WithAttributes(span,
telemetry.AttributeKV{Key: "project-id", Value: proj.ID},
telemetry.AttributeKV{Key: "customer-id", Value: proj.BillingID},
telemetry.AttributeKV{Key: "user-email", Value: user.Email},
)

// Update the project record with the customer ID
Expand All @@ -72,5 +64,40 @@ func (c *CreateBillingCustomerHandler) ServeHTTP(w http.ResponseWriter, r *http.
return
}

c.WriteResult(w, r, "")
}

// GetPublishableKeyHandler will return the configured publishable key
type GetPublishableKeyHandler struct {
handlers.PorterHandlerReadWriter
}

// NewGetPublishableKeyHandler will return the publishable key
func NewGetPublishableKeyHandler(
config *config.Config,
decoderValidator shared.RequestDecoderValidator,
writer shared.ResultWriter,
) *GetPublishableKeyHandler {
return &GetPublishableKeyHandler{
PorterHandlerReadWriter: handlers.NewDefaultPorterHandler(config, decoderValidator, writer),
}
}

func (c *GetPublishableKeyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx, span := telemetry.NewSpan(r.Context(), "get-publishable-key-endpoint")
defer span.End()

proj, _ := ctx.Value(types.ProjectScope).(*models.Project)

// There is no easy way to pass environment variables to the frontend,
// so for now pass via the backend. This is acceptable because the key is
// meant to be public
publishableKey := c.Config().BillingManager.GetPublishableKey(ctx)

telemetry.WithAttributes(span,
telemetry.AttributeKV{Key: "project-id", Value: proj.ID},
telemetry.AttributeKV{Key: "customer-id", Value: proj.BillingID},
)

c.WriteResult(w, r, publishableKey)
}
6 changes: 6 additions & 0 deletions api/server/handlers/project/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ func (p *ProjectCreateHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
return
}
proj.BillingID = billingID

telemetry.WithAttributes(span,
telemetry.AttributeKV{Key: "project-id", Value: proj.ID},
telemetry.AttributeKV{Key: "customer-id", Value: proj.BillingID},
telemetry.AttributeKV{Key: "user-email", Value: user.Email},
)
}

proj, _, err = CreateProjectWithUser(p.Repo().Project(), proj, user)
Expand Down
28 changes: 27 additions & 1 deletion api/server/router/project.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,6 @@ func getProjectRoutes(

getOrCreateBillingCustomerHandler := billing.NewCreateBillingCustomerIfNotExists(
config,
factory.GetDecoderValidator(),
factory.GetResultWriter(),
)

Expand All @@ -451,6 +450,33 @@ func getProjectRoutes(
Router: r,
})

// GET /api/projects/{project_id}/billing/publishable_key -> project.NewGetPublishableKeyHandler
publishableKeyEndpoint := factory.NewAPIEndpoint(
&types.APIRequestMetadata{
Verb: types.APIVerbGet,
Method: types.HTTPVerbGet,
Path: &types.Path{
Parent: basePath,
RelativePath: relPath + "/billing/publishable_key",
},
Scopes: []types.PermissionScope{
types.ProjectScope,
},
},
)

publishableKeyHandler := billing.NewGetPublishableKeyHandler(
config,
factory.GetDecoderValidator(),
factory.GetResultWriter(),
)

routes = append(routes, &router.Route{
Endpoint: publishableKeyEndpoint,
Handler: publishableKeyHandler,
Router: r,
})

// GET /api/projects/{project_id}/clusters -> cluster.NewClusterListHandler
listClusterEndpoint := factory.NewAPIEndpoint(
&types.APIRequestMetadata{
Expand Down
5 changes: 0 additions & 5 deletions api/types/billing.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
package types

// CreateBillingCustomerRequest is a request for creating a new billing customer.
type CreateBillingCustomerRequest struct {
UserEmail string `json:"user_email" form:"required"`
}

// PaymentMethod is a subset of the Stripe PaymentMethod type,
// with only the fields used in the dashboard
type PaymentMethod = struct {
Expand Down
35 changes: 29 additions & 6 deletions dashboard/src/lib/hooks/useStripe.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type TCheckHasPaymentEnabled = {
refetchPaymentEnabled: any;
};

type TCheckCustomerExists = {
type TGetPublishableKey = {
publishableKey: string;
};

Expand Down Expand Up @@ -150,20 +150,43 @@ export const checkIfProjectHasPayment = (): TCheckHasPaymentEnabled => {
};
};

export const checkBillingCustomerExists = (): TCheckCustomerExists => {
export const checkBillingCustomerExists = () => {
const { currentProject } = useContext(Context);

useQuery(["checkCustomerExists", currentProject?.id], async () => {
if (!currentProject?.id || currentProject.id === -1) {
return;
}

if (!currentProject?.billing_enabled) {
return;
}

const res = await api.checkBillingCustomerExists(
"<token>",
{},
{ project_id: currentProject?.id }
);
return res.data;
});
};

export const usePublishableKey = (): TGetPublishableKey => {
const { user, currentProject } = useContext(Context);

// Fetch list of payment methods
const keyReq = useQuery(
["checkCustomerExists", currentProject?.id],
["getPublishableKey", currentProject?.id],
async () => {
if (!currentProject?.id || currentProject.id === -1) {
return;
}
const res = await api.checkBillingCustomerExists(
const res = await api.getPublishableKey(
"<token>",
{ user_email: user?.email },
{ project_id: currentProject?.id }
{},
{
project_id: currentProject?.id,
}
);
return res.data;
}
Expand Down
4 changes: 4 additions & 0 deletions dashboard/src/main/home/Home.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import Modal from "components/porter/Modal";
import ShowIntercomButton from "components/porter/ShowIntercomButton";
import Spacer from "components/porter/Spacer";
import Text from "components/porter/Text";
import { checkBillingCustomerExists } from "lib/hooks/useStripe";

import api from "shared/api";
import { withAuth, type WithAuthProps } from "shared/auth/AuthorizationHoc";
Expand Down Expand Up @@ -293,6 +294,9 @@ const Home: React.FC<Props> = (props) => {
prevCurrentCluster.current = props.currentCluster;
}, [props.currentCluster]);

// Create Stripe customer if it doesn't exists already
checkBillingCustomerExists();

const projectOverlayCall = async () => {
try {
const projectList = await api
Expand Down
4 changes: 2 additions & 2 deletions dashboard/src/main/home/modals/BillingModal.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import Link from "components/porter/Link";
import Modal from "components/porter/Modal";
import Spacer from "components/porter/Spacer";
import Text from "components/porter/Text";
import { checkBillingCustomerExists } from "lib/hooks/useStripe";
import { usePublishableKey } from "lib/hooks/useStripe";

import PaymentSetupForm from "./PaymentSetupForm";

Expand All @@ -17,7 +17,7 @@ const BillingModal = ({
back: (value: React.SetStateAction<boolean>) => void;
onCreate: () => Promise<void>;
}) => {
const { publishableKey } = checkBillingCustomerExists();
const { publishableKey } = usePublishableKey();
const stripePromise = loadStripe(publishableKey);

const appearance = {
Expand Down
2 changes: 0 additions & 2 deletions dashboard/src/main/home/project-settings/BillingPage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import Image from "components/porter/Image";
import Spacer from "components/porter/Spacer";
import Text from "components/porter/Text";
import {
checkBillingCustomerExists,
checkIfProjectHasPayment,
usePaymentMethods,
useSetDefaultPaymentMethod,
Expand All @@ -34,7 +33,6 @@ function BillingPage(): JSX.Element {
deletingIds,
} = usePaymentMethods();
const { setDefaultPaymentMethod } = useSetDefaultPaymentMethod();
checkBillingCustomerExists();

const { refetchPaymentEnabled } = checkIfProjectHasPayment();

Expand Down
15 changes: 12 additions & 3 deletions dashboard/src/shared/api.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3442,14 +3442,22 @@ const removeStackEnvGroup = baseApi<

// Billing
const checkBillingCustomerExists = baseApi<
{
user_email?: string;
},
{},
{
project_id?: number;
}
>("POST", ({ project_id }) => `/api/projects/${project_id}/billing/customer`);

const getPublishableKey = baseApi<
{},
{
project_id?: number;
}
>(
"GET",
({ project_id }) => `/api/projects/${project_id}/billing/publishable_key`
);

const getHasBilling = baseApi<{}, { project_id: number }>(
"GET",
({ project_id }) => `/api/projects/${project_id}/billing`
Expand Down Expand Up @@ -3847,6 +3855,7 @@ export default {

// BILLING
checkBillingCustomerExists,
getPublishableKey,
listPaymentMethod,
addPaymentMethod,
setDefaultPaymentMethod,
Expand Down
5 changes: 5 additions & 0 deletions internal/billing/stripe.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package billing
import (
"context"
"fmt"
"strconv"

"github.com/porter-dev/porter/api/types"
"github.com/porter-dev/porter/internal/models"
Expand Down Expand Up @@ -34,9 +35,13 @@ func (s *StripeBillingManager) CreateCustomer(ctx context.Context, userEmail str
if proj.BillingID == "" {
// Create customer if not exists
customerName := fmt.Sprintf("project_%s", proj.Name)
projectIDStr := strconv.FormatUint(uint64(proj.ID), 10)
params := &stripe.CustomerParams{
Name: stripe.String(customerName),
Email: stripe.String(userEmail),
Metadata: map[string]string{
"porter_project_id": projectIDStr,
},
}

// Create in Stripe
Expand Down

0 comments on commit 867b89e

Please sign in to comment.