diff --git a/api/server/handlers/billing/create.go b/api/server/handlers/billing/create.go index c784d377d3..881d8856d8 100644 --- a/api/server/handlers/billing/create.go +++ b/api/server/handlers/billing/create.go @@ -1,8 +1,10 @@ package billing import ( + "context" "fmt" "net/http" + "time" "github.com/porter-dev/porter/api/server/handlers" "github.com/porter-dev/porter/api/server/shared" @@ -41,6 +43,7 @@ func (c *CreateBillingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) defer span.End() proj, _ := ctx.Value(types.ProjectScope).(*models.Project) + user, _ := ctx.Value(types.UserScope).(*models.User) clientSecret, err := c.Config().BillingManager.StripeClient.CreatePaymentMethod(ctx, proj.BillingID) if err != nil { @@ -54,6 +57,16 @@ func (c *CreateBillingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) telemetry.AttributeKV{Key: "customer-id", Value: proj.BillingID}, ) + if proj.EnableSandbox { + // Grant a reward to the project that referred this user after linking a payment method + err = c.grantRewardIfReferral(ctx, user.ID) + if err != nil { + err := telemetry.Error(ctx, span, err, "error granting credits reward") + c.HandleAPIError(w, r, apierrors.NewErrInternal(err)) + return + } + } + c.WriteResult(w, r, clientSecret) } @@ -104,3 +117,53 @@ func (c *SetDefaultBillingHandler) ServeHTTP(w http.ResponseWriter, r *http.Requ c.WriteResult(w, r, "") } + +func (c *CreateBillingHandler) grantRewardIfReferral(ctx context.Context, referredUserID uint) (err error) { + ctx, span := telemetry.NewSpan(ctx, "grant-referral-reward") + defer span.End() + + referral, err := c.Repo().Referral().GetReferralByReferredID(referredUserID) + if err != nil { + return telemetry.Error(ctx, span, err, "failed to find referral by referred id") + } + + if referral == nil { + return nil + } + + referralCount, err := c.Repo().Referral().CountReferralsByProjectID(referral.ProjectID, models.ReferralStatusCompleted) + if err != nil { + return telemetry.Error(ctx, span, err, "failed to get referral count by referrer id") + } + + maxReferralRewards := c.Config().BillingManager.MetronomeClient.MaxReferralRewards + if referralCount >= maxReferralRewards { + return nil + } + + referrerProject, err := c.Repo().Project().ReadProject(referral.ProjectID) + if err != nil { + return telemetry.Error(ctx, span, err, "failed to find referrer project") + } + + if referral != nil && referral.Status != models.ReferralStatusCompleted { + // Metronome requires an expiration to be passed in, so we set it to 5 years which in + // practice will mean the credits will most likely run out before expiring + expiresAt := time.Now().AddDate(5, 0, 0).Format(time.RFC3339) + reason := "Referral reward" + rewardAmount := c.Config().BillingManager.MetronomeClient.DefaultRewardAmountCents + paidAmount := c.Config().BillingManager.MetronomeClient.DefaultPaidAmountCents + err := c.Config().BillingManager.MetronomeClient.CreateCreditsGrant(ctx, referrerProject.UsageID, reason, rewardAmount, paidAmount, expiresAt) + if err != nil { + return telemetry.Error(ctx, span, err, "failed to grand credits reward") + } + + referral.Status = models.ReferralStatusCompleted + _, err = c.Repo().Referral().UpdateReferral(referral) + if err != nil { + return telemetry.Error(ctx, span, err, "error while updating referral") + } + } + + return nil +} diff --git a/api/server/handlers/project/create.go b/api/server/handlers/project/create.go index 27ff346f56..af891b2cec 100644 --- a/api/server/handlers/project/create.go +++ b/api/server/handlers/project/create.go @@ -67,6 +67,9 @@ func (p *ProjectCreateHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) if p.Config().ServerConf.EnableSandbox { step = types.StepCleanUp + + // Generate referral code for porter cloud projects + proj.ReferralCode = models.NewReferralCode() } // create onboarding flow set to the first step. Read in env var diff --git a/api/server/handlers/project/referrals.go b/api/server/handlers/project/referrals.go new file mode 100644 index 0000000000..80669a2f95 --- /dev/null +++ b/api/server/handlers/project/referrals.go @@ -0,0 +1,81 @@ +package project + +import ( + "net/http" + + "github.com/google/uuid" + "github.com/porter-dev/porter/api/server/handlers" + "github.com/porter-dev/porter/api/server/shared" + "github.com/porter-dev/porter/api/server/shared/apierrors" + "github.com/porter-dev/porter/api/server/shared/config" + "github.com/porter-dev/porter/api/types" + "github.com/porter-dev/porter/internal/models" + "github.com/porter-dev/porter/internal/telemetry" +) + +// GetProjectReferralDetailsHandler is a handler for getting a project's referral code +type GetProjectReferralDetailsHandler struct { + handlers.PorterHandlerWriter +} + +// NewGetProjectReferralDetailsHandler returns an instance of GetProjectReferralDetailsHandler +func NewGetProjectReferralDetailsHandler( + config *config.Config, + writer shared.ResultWriter, +) *GetProjectReferralDetailsHandler { + return &GetProjectReferralDetailsHandler{ + PorterHandlerWriter: handlers.NewDefaultPorterHandler(config, nil, writer), + } +} + +func (c *GetProjectReferralDetailsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx, span := telemetry.NewSpan(r.Context(), "serve-get-project-referral-details") + defer span.End() + + proj, _ := ctx.Value(types.ProjectScope).(*models.Project) + + if !c.Config().BillingManager.MetronomeConfigLoaded || !proj.GetFeatureFlag(models.MetronomeEnabled, c.Config().LaunchDarklyClient) || + proj.UsageID == uuid.Nil || !proj.EnableSandbox { + c.WriteResult(w, r, "") + + telemetry.WithAttributes(span, + telemetry.AttributeKV{Key: "metronome-config-exists", Value: c.Config().BillingManager.MetronomeConfigLoaded}, + telemetry.AttributeKV{Key: "metronome-enabled", Value: proj.GetFeatureFlag(models.MetronomeEnabled, c.Config().LaunchDarklyClient)}, + ) + return + } + + if proj.ReferralCode == "" { + telemetry.WithAttributes(span, + telemetry.AttributeKV{Key: "referral-code-exists", Value: false}, + ) + + // Generate referral code for project if not present + proj.ReferralCode = models.NewReferralCode() + _, err := c.Repo().Project().UpdateProject(proj) + if err != nil { + err := telemetry.Error(ctx, span, err, "error updating project") + c.HandleAPIError(w, r, apierrors.NewErrInternal(err)) + return + } + } + + referralCount, err := c.Repo().Referral().CountReferralsByProjectID(proj.ID, models.ReferralStatusCompleted) + if err != nil { + err := telemetry.Error(ctx, span, err, "error listing referrals by project id") + c.HandleAPIError(w, r, apierrors.NewErrInternal(err)) + return + } + + referralCodeResponse := struct { + Code string `json:"code"` + ReferralCount int64 `json:"referral_count"` + MaxAllowedRewards int64 `json:"max_allowed_referrals"` + }{ + Code: proj.ReferralCode, + ReferralCount: referralCount, + MaxAllowedRewards: c.Config().BillingManager.MetronomeClient.MaxReferralRewards, + } + + c.WriteResult(w, r, referralCodeResponse) +} diff --git a/api/server/handlers/user/create.go b/api/server/handlers/user/create.go index 0b1571cce1..403a265346 100644 --- a/api/server/handlers/user/create.go +++ b/api/server/handlers/user/create.go @@ -72,14 +72,12 @@ func (u *UserCreateHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // write the user to the db user, err = u.Repo().User().CreateUser(user) - if err != nil { u.HandleAPIError(w, r, apierrors.NewErrInternal(err)) return } err = addUserToDefaultProject(u.Config(), user) - if err != nil { u.HandleAPIError(w, r, apierrors.NewErrInternal(err)) return @@ -95,7 +93,20 @@ func (u *UserCreateHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // non-fatal send email verification if !user.EmailVerified { err = startEmailVerification(u.Config(), w, r, user) + if err != nil { + u.HandleAPIErrorNoWrite(w, r, apierrors.NewErrInternal(err)) + } + } + + // create referral if referred by another user + if request.ReferredBy != "" { + referral := &models.Referral{ + Code: request.ReferredBy, + ReferredUserID: user.ID, + Status: models.ReferralStatusSignedUp, + } + _, err = u.Repo().Referral().CreateReferral(referral) if err != nil { u.HandleAPIErrorNoWrite(w, r, apierrors.NewErrInternal(err)) } @@ -146,7 +157,6 @@ func addUserToDefaultProject(config *config.Config, user *models.User) error { Kind: types.RoleAdmin, }, }) - if err != nil { return err } diff --git a/api/server/router/project.go b/api/server/router/project.go index d5e9e851fb..e44fed515d 100644 --- a/api/server/router/project.go +++ b/api/server/router/project.go @@ -397,6 +397,33 @@ func getProjectRoutes( Router: r, }) + // GET /api/projects/{project_id}/referrals/details -> user.NewGetUserReferralDetailsHandler + getReferralDetailsEndpoint := factory.NewAPIEndpoint( + &types.APIRequestMetadata{ + Verb: types.APIVerbGet, + Method: types.HTTPVerbGet, + Path: &types.Path{ + Parent: basePath, + RelativePath: relPath + "/referrals/details", + }, + Scopes: []types.PermissionScope{ + types.UserScope, + types.ProjectScope, + }, + }, + ) + + getReferralDetailsHandler := project.NewGetProjectReferralDetailsHandler( + config, + factory.GetResultWriter(), + ) + + routes = append(routes, &router.Route{ + Endpoint: getReferralDetailsEndpoint, + Handler: getReferralDetailsHandler, + Router: r, + }) + // POST /api/projects/{project_id}/billing/usage -> project.NewListCustomerUsageHandler listCustomerUsageEndpoint := factory.NewAPIEndpoint( &types.APIRequestMetadata{ diff --git a/api/types/billing_metronome.go b/api/types/billing_metronome.go index 502afa244a..2347b504e8 100644 --- a/api/types/billing_metronome.go +++ b/api/types/billing_metronome.go @@ -55,6 +55,19 @@ type EndCustomerPlanRequest struct { VoidStripeInvoices bool `json:"void_stripe_invoices"` } +// CreateCreditsGrantRequest is the request to create a credit grant for a customer +type CreateCreditsGrantRequest struct { + // CustomerID is the id of the customer + CustomerID uuid.UUID `json:"customer_id"` + UniquenessKey string `json:"uniqueness_key"` + GrantAmount GrantAmountID `json:"grant_amount"` + PaidAmount PaidAmount `json:"paid_amount"` + Name string `json:"name"` + ExpiresAt string `json:"expires_at"` + Priority int `json:"priority"` + Reason string `json:"reason"` +} + // ListCreditGrantsRequest is the request to list a user's credit grants. Note that only one of // CreditTypeIDs, CustomerIDs, or CreditGrantIDs must be specified. type ListCreditGrantsRequest struct { @@ -73,18 +86,6 @@ type ListCreditGrantsResponse struct { GrantedCredits float64 `json:"granted_credits"` } -// EmbeddableDashboardRequest requests an embeddable customer dashboard to Metronome -type EmbeddableDashboardRequest struct { - // CustomerID is the id of the customer - CustomerID uuid.UUID `json:"customer_id,omitempty"` - // DashboardType is the type of dashboard to retrieve - DashboardType string `json:"dashboard"` - // Options are optional dashboard specific options - Options []DashboardOption `json:"dashboard_options,omitempty"` - // ColorOverrides is an optional list of colors to override - ColorOverrides []ColorOverride `json:"color_overrides,omitempty"` -} - // ListCustomerUsageRequest is the request to list usage for a customer type ListCustomerUsageRequest struct { CustomerID uuid.UUID `json:"customer_id"` @@ -138,12 +139,33 @@ type CreditType struct { ID string `json:"id"` } -// GrantAmount represents the amount of credits granted +// GrantAmountID represents the amount of credits granted with the credit type ID +// for the create credits grant request +type GrantAmountID struct { + Amount float64 `json:"amount"` + CreditTypeID uuid.UUID `json:"credit_type_id"` +} + +// GrantAmount represents the amount of credits granted with the credit type +// for the list credit grants response type GrantAmount struct { Amount float64 `json:"amount"` CreditType CreditType `json:"credit_type"` } +// PaidAmount represents the amount paid by the customer +type PaidAmount struct { + Amount float64 `json:"amount"` + CreditTypeID uuid.UUID `json:"credit_type_id"` +} + +// PricingUnit represents the unit of the pricing (e.g. USD, MXN, CPU hours) +type PricingUnit struct { + ID uuid.UUID `json:"id"` + Name string `json:"name"` + IsCurrency bool `json:"is_currency"` +} + // Balance represents the effective balance of the grant as of the end of the customer's // current billing period. type Balance struct { @@ -166,18 +188,6 @@ type CreditGrant struct { ExpiresAt string `json:"expires_at"` } -// DashboardOption are optional dashboard specific options -type DashboardOption struct { - Key string `json:"key"` - Value string `json:"value"` -} - -// ColorOverride is an optional list of colors to override -type ColorOverride struct { - Name string `json:"name"` - Value string `json:"value"` -} - // BillingEvent represents a Metronome billing event. type BillingEvent struct { CustomerID string `json:"customer_id"` diff --git a/api/types/project.go b/api/types/project.go index 1bf9497dac..3f537cccf7 100644 --- a/api/types/project.go +++ b/api/types/project.go @@ -60,6 +60,8 @@ type Project struct { AdvancedInfraEnabled bool `json:"advanced_infra_enabled"` SandboxEnabled bool `json:"sandbox_enabled"` AdvancedRbacEnabled bool `json:"advanced_rbac_enabled"` + // ReferralCode is a unique code that can be shared to referr other users to Porter + ReferralCode string `json:"referral_code"` } // FeatureFlags is a struct that contains old feature flag representations diff --git a/api/types/referral.go b/api/types/referral.go new file mode 100644 index 0000000000..dbdb50b154 --- /dev/null +++ b/api/types/referral.go @@ -0,0 +1,12 @@ +package types + +// Referral is a struct that represents a referral in the Porter API +type Referral struct { + ID uint `json:"id"` + // Code is the referral code that is shared with the referred user + Code string `json:"referral_code"` + // ReferredUserID is the ID of the user who was referred + ReferredUserID uint `json:"referred_user_id"` + // Status is the status of the referral (pending, signed_up, etc.) + Status string `json:"status"` +} diff --git a/api/types/user.go b/api/types/user.go index ca815ed1b0..41b33b6b34 100644 --- a/api/types/user.go +++ b/api/types/user.go @@ -16,6 +16,8 @@ type CreateUserRequest struct { LastName string `json:"last_name" form:"required,max=255"` CompanyName string `json:"company_name" form:"required,max=255"` ReferralMethod string `json:"referral_method" form:"max=255"` + // ReferredBy is the referral code of the project from which this user was referred + ReferredBy string `json:"referred_by_code" form:"max=255"` } type CreateUserResponse User diff --git a/dashboard/src/lib/billing/types.tsx b/dashboard/src/lib/billing/types.tsx index 81c7a85fc2..5228fc835a 100644 --- a/dashboard/src/lib/billing/types.tsx +++ b/dashboard/src/lib/billing/types.tsx @@ -17,13 +17,15 @@ const TrialValidator = z.object({ }); export type Plan = z.infer; -export const PlanValidator = z.object({ - id: z.string(), - plan_name: z.string(), - plan_description: z.string(), - starting_on: z.string(), - trial_info: TrialValidator, -}).nullable(); +export const PlanValidator = z + .object({ + id: z.string(), + plan_name: z.string(), + plan_description: z.string(), + starting_on: z.string(), + trial_info: TrialValidator, + }) + .nullable(); export type UsageMetric = z.infer; export const UsageMetricValidator = z.object({ @@ -50,3 +52,12 @@ export const CreditGrantsValidator = z.object({ }); export const ClientSecretResponse = z.string(); + +export type ReferralDetails = z.infer; +export const ReferralDetailsValidator = z + .object({ + code: z.string(), + referral_count: z.number(), + max_allowed_referrals: z.number(), + }) + .nullable(); diff --git a/dashboard/src/lib/hooks/useStripe.tsx b/dashboard/src/lib/hooks/useStripe.tsx index 8d9e96a791..037da5d6a2 100644 --- a/dashboard/src/lib/hooks/useStripe.tsx +++ b/dashboard/src/lib/hooks/useStripe.tsx @@ -13,6 +13,8 @@ import { type PaymentMethod, type PaymentMethodList, type UsageList, + ReferralDetailsValidator, + ReferralDetails } from "lib/billing/types"; import api from "shared/api"; @@ -60,6 +62,10 @@ type TGetUsage = { usage: UsageList | null; }; +type TGetReferralDetails = { + referralDetails: ReferralDetails +}; + export const usePaymentMethods = (): TUsePaymentMethod => { const { currentProject } = useContext(Context); @@ -367,3 +373,37 @@ export const useCustomerUsage = ( usage: usageReq.data ?? null, }; }; + +export const useReferralDetails = (): TGetReferralDetails => { + const { currentProject } = useContext(Context); + + // Fetch user's referral code + const referralsReq = useQuery( + ["getReferralDetails", currentProject?.id], + async (): Promise => { + if (!currentProject?.metronome_enabled) { + return null; + } + + if (!currentProject?.id || currentProject.id === -1) { + return null; + } + + try { + const res = await api.getReferralDetails( + "", + {}, + { project_id: currentProject?.id } + ); + + const referraldetails = ReferralDetailsValidator.parse(res.data); + return referraldetails; + } catch (error) { + return null + } + }); + + return { + referralDetails: referralsReq.data ?? null, + }; +}; diff --git a/dashboard/src/main/auth/Register.tsx b/dashboard/src/main/auth/Register.tsx index dae6c25ddb..36312f279f 100644 --- a/dashboard/src/main/auth/Register.tsx +++ b/dashboard/src/main/auth/Register.tsx @@ -1,5 +1,6 @@ import React, { useContext, useEffect, useState } from "react"; import styled from "styled-components"; +import { useLocation } from "react-router-dom"; import Heading from "components/form-components/Heading"; import Button from "components/porter/Button"; @@ -35,6 +36,9 @@ const Register: React.FC = ({ authenticate }) => { const [lastName, setLastName] = useState(""); const [lastNameError, setLastNameError] = useState(false); const [companyName, setCompanyName] = useState(""); + const [referralCode, setReferralCode] = useState(""); + const [referralCodeError, setReferralCodeError] = useState(false); + const [companyNameError, setCompanyNameError] = useState(false); const [email, setEmail] = useState(""); const [emailError, setEmailError] = useState(false); @@ -71,6 +75,16 @@ const Register: React.FC = ({ authenticate }) => { { value: "Other", label: "Other" }, ]; + const { search } = useLocation() + const searchParams = new URLSearchParams(search) + const referralCodeFromUrl = searchParams.get("referral") + + useEffect(() => { + if (referralCodeFromUrl) { + setReferralCode(referralCodeFromUrl); + } + }, [referralCodeFromUrl]); // Only re-run the effect if referralCodeFromUrl changes + const handleRegister = (): void => { const isHosted = window.location.hostname === "cloud.porter.run"; if (!emailRegex.test(email)) { @@ -118,6 +132,7 @@ const Register: React.FC = ({ authenticate }) => { chosenReferralOption === "Other" ? `Other: ${referralOtherText}` : chosenReferralOption, + referred_by_code: referralCode, }, {} ) @@ -171,6 +186,7 @@ const Register: React.FC = ({ authenticate }) => { chosenReferralOption === "Other" ? `Other: ${referralOtherText}` : chosenReferralOption, + referred_by_code: referralCode, }, {} ) @@ -178,7 +194,7 @@ const Register: React.FC = ({ authenticate }) => { if (res?.data?.redirect) { window.location.href = res.data.redirect; } else { - setUser(res?.data?.id, res?.data?.email); + setUser(res?.data?.id); authenticate(); try { @@ -400,6 +416,21 @@ const Register: React.FC = ({ authenticate }) => { setValue={setChosenReferralOption} value={chosenReferralOption} /> + + { + setReferralCode(x); + setReferralCodeError(false); + }} + width="100%" + height="40px" + error={referralCodeError && ""} + /> + + {chosenReferralOption === "Other" && ( <> diff --git a/dashboard/src/main/home/project-settings/BillingPage.tsx b/dashboard/src/main/home/project-settings/BillingPage.tsx index 666090a1f2..a02763f6cc 100644 --- a/dashboard/src/main/home/project-settings/BillingPage.tsx +++ b/dashboard/src/main/home/project-settings/BillingPage.tsx @@ -1,12 +1,18 @@ import React, { useContext, useMemo, useState } from "react"; +import dayjs from "dayjs"; +import relativeTime from "dayjs/plugin/relativeTime"; import styled from "styled-components"; +import CopyToClipboard from "components/CopyToClipboard"; import Loading from "components/Loading"; +import Banner from "components/porter/Banner"; import Button from "components/porter/Button"; import Container from "components/porter/Container"; import Fieldset from "components/porter/Fieldset"; import Icon from "components/porter/Icon"; import Image from "components/porter/Image"; +import Link from "components/porter/Link"; +import Modal from "components/porter/Modal"; import Spacer from "components/porter/Spacer"; import Text from "components/porter/Text"; import { @@ -15,10 +21,9 @@ import { useCustomerUsage, usePaymentMethods, usePorterCredits, + useReferralDetails, useSetDefaultPaymentMethod, } from "lib/hooks/useStripe"; -import dayjs from "dayjs"; -import relativeTime from "dayjs/plugin/relativeTime"; import { Context } from "shared/Context"; import cardIcon from "assets/credit-card.svg"; @@ -31,8 +36,10 @@ import Bars from "./Bars"; dayjs.extend(relativeTime); function BillingPage(): JSX.Element { + const { referralDetails } = useReferralDetails(); const { setCurrentOverlay } = useContext(Context); const [shouldCreate, setShouldCreate] = useState(false); + const [showReferralModal, setShowReferralModal] = useState(false); const { currentProject } = useContext(Context); const { creditGrants } = usePorterCredits(); @@ -93,6 +100,16 @@ function BillingPage(): JSX.Element { await refetchPaymentEnabled({ throwOnError: false, cancelRefetch: false }); }; + const isTrialExpired = (timestamp: string): boolean => { + if (timestamp === "") { + return true; + } + const timestampDate = dayjs(timestamp); + return timestampDate.isBefore(dayjs(new Date())); + }; + + const trialExpired = plan && isTrialExpired(plan.trial_info.ending_before); + if (shouldCreate) { return ( + {plan?.trial_info !== undefined && + plan.trial_info.ending_before !== "" && + !trialExpired && ( + <> + + Your free trial is ending{" "} + {dayjs().to(dayjs(plan.trial_info.ending_before))}. + + + + )} + {currentProject?.metronome_enabled && currentProject?.sandbox_enabled && ( + <> + Credit balance + + + View the amount of Porter credits you have remaining to spend on + resources in this project. + + + + + + + {creditGrants && creditGrants.remaining_credits > 0 + ? `$${formatCredits(creditGrants.remaining_credits)}` + : "$ 0.00"} + + + + + Earn additional free credits by{" "} + { + setShowReferralModal(true); + }} + > + referring users to Porter + + . + + + + )} Payment methods @@ -179,116 +241,98 @@ function BillingPage(): JSX.Element { onClick={() => { setShouldCreate(true); }} + alt > add - Add Payment Method + Add payment method - {currentProject?.metronome_enabled && ( -
- - {currentProject?.sandbox_enabled && ( -
- Porter credit grants - + {currentProject?.metronome_enabled && plan && plan.plan_name !== "" ? ( + <> + Current usage + + + View the current usage of this billing period. + + + {usage?.length && + usage.length > 0 && + usage[0].usage_metrics.length > 0 ? ( + + + + + + + + + + ) : ( +
- View the amount of Porter credits you have available to spend on - resources within this project. + No usage data available for this billing period. - - - - - - - {creditGrants && - creditGrants.remaining_credits > 0 - ? `$${formatCredits( - creditGrants.remaining_credits - )}/$${formatCredits(creditGrants.granted_credits)}` - : "$ 0.00"} - - - -
+ )} - -
- Plan Details - - - View the details of the current billing plan of this project. - - - - {plan && plan.plan_name !== "" ? ( -
- Active Plan - -
- - - {plan.plan_name} - - - {plan.trial_info !== undefined && - plan.trial_info.ending_before !== "" ? ( - - Free trial ends{" "} - {dayjs().to(dayjs(plan.trial_info.ending_before))} - - ) : ( - Started on {readableDate(plan.starting_on)} - )} - - -
- - Current Usage - - - View the current usage of this billing period. - - - {usage?.length && - usage.length > 0 && - usage[0].usage_metrics.length > 0 ? ( - - - - - - - - - - ) : ( -
- - No usage data available for this billing period. - -
- )} - -
- ) : ( - This project does not have an active billing plan. - )} -
-
+ + + ) : ( + This project does not have an active billing plan. + )} + {showReferralModal && ( + { + setShowReferralModal(false); + }} + > + Refer users to Porter + + + Earn $10 in free credits for each user you refer to Porter. Referred + users need to connect a payment method for credits to be added to + your account. + + + + + Referral code:{" "} + {currentProject?.referral_code ? ( + {currentProject.referral_code} + ) : ( + "n/a" + )} + + + + Copy referral link + + + + + You have referred{" "} + {referralDetails ? referralDetails.referral_count : "?"}/{referralDetails?.max_allowed_referrals} users. + + )} ); @@ -296,6 +340,25 @@ function BillingPage(): JSX.Element { export default BillingPage; +const CopyButton = styled.div` + cursor: pointer; + background: #ffffff11; + padding: 5px; + border-radius: 5px; + font-size: 13px; +`; + +const Code = styled.span` + font-style: italic; +`; + +const ReferralCode = styled.div` + background: linear-gradient(60deg, #4b366d 0%, #6475b9 100%); + padding: 10px 15px; + border-radius: 10px; + width: fit-content; +`; + const Flex = styled.div` display: flex; flex-wrap: wrap; @@ -308,8 +371,8 @@ const BarWrapper = styled.div` `; const I = styled.i` - font-size: 18px; - margin-right: 10px; + font-size: 16px; + margin-right: 8px; `; const DeleteButton = styled.div` diff --git a/dashboard/src/main/home/project-settings/InviteList.tsx b/dashboard/src/main/home/project-settings/InviteList.tsx index 4561c89746..80d6511fb7 100644 --- a/dashboard/src/main/home/project-settings/InviteList.tsx +++ b/dashboard/src/main/home/project-settings/InviteList.tsx @@ -684,8 +684,8 @@ const InvitePage: React.FunctionComponent = ({}) => { export default InvitePage; const I = styled.i` - margin-right: 10px; - font-size: 18px; + margin-right: 8px; + font-size: 16px; `; const Flex = styled.div` diff --git a/dashboard/src/main/home/project-settings/ReferralsPage.tsx b/dashboard/src/main/home/project-settings/ReferralsPage.tsx new file mode 100644 index 0000000000..2f2a19cef9 --- /dev/null +++ b/dashboard/src/main/home/project-settings/ReferralsPage.tsx @@ -0,0 +1,36 @@ +import React from "react"; + +import Link from "components/porter/Link"; +import Spacer from "components/porter/Spacer"; +import Text from "components/porter/Text"; +import { useReferralDetails } from "lib/hooks/useStripe"; + +function ReferralsPage(): JSX.Element { + const { referralDetails } = useReferralDetails(); + const baseUrl = window.location.origin; + + return ( + <> + Referrals + + Refer people to Porter to earn credits. + + {referralDetails !== null && ( + <> + Your referral link is + { const { project_id, cluster_id, stack_name, page } = pathParams; - return `/api/projects/${project_id}/clusters/${cluster_id}/applications/${stack_name}/events?page=${ - page || 1 - }`; + return `/api/projects/${project_id}/clusters/${cluster_id}/applications/${stack_name}/events?page=${page || 1 + }`; }); const createEnvironment = baseApi< @@ -876,11 +875,9 @@ const detectBuildpack = baseApi< branch: string; } >("GET", (pathParams) => { - return `/api/projects/${pathParams.project_id}/gitrepos/${ - pathParams.git_repo_id - }/repos/${pathParams.kind}/${pathParams.owner}/${ - pathParams.name - }/${encodeURIComponent(pathParams.branch)}/buildpack/detect`; + return `/api/projects/${pathParams.project_id}/gitrepos/${pathParams.git_repo_id + }/repos/${pathParams.kind}/${pathParams.owner}/${pathParams.name + }/${encodeURIComponent(pathParams.branch)}/buildpack/detect`; }); const detectGitlabBuildpack = baseApi< @@ -911,11 +908,9 @@ const getBranchContents = baseApi< branch: string; } >("GET", (pathParams) => { - return `/api/projects/${pathParams.project_id}/gitrepos/${ - pathParams.git_repo_id - }/repos/${pathParams.kind}/${pathParams.owner}/${ - pathParams.name - }/${encodeURIComponent(pathParams.branch)}/contents`; + return `/api/projects/${pathParams.project_id}/gitrepos/${pathParams.git_repo_id + }/repos/${pathParams.kind}/${pathParams.owner}/${pathParams.name + }/${encodeURIComponent(pathParams.branch)}/contents`; }); const getProcfileContents = baseApi< @@ -931,11 +926,9 @@ const getProcfileContents = baseApi< branch: string; } >("GET", (pathParams) => { - return `/api/projects/${pathParams.project_id}/gitrepos/${ - pathParams.git_repo_id - }/repos/${pathParams.kind}/${pathParams.owner}/${ - pathParams.name - }/${encodeURIComponent(pathParams.branch)}/procfile`; + return `/api/projects/${pathParams.project_id}/gitrepos/${pathParams.git_repo_id + }/repos/${pathParams.kind}/${pathParams.owner}/${pathParams.name + }/${encodeURIComponent(pathParams.branch)}/procfile`; }); const getPorterYamlContents = baseApi< @@ -951,11 +944,9 @@ const getPorterYamlContents = baseApi< branch: string; } >("GET", (pathParams) => { - return `/api/projects/${pathParams.project_id}/gitrepos/${ - pathParams.git_repo_id - }/repos/${pathParams.kind}/${pathParams.owner}/${ - pathParams.name - }/${encodeURIComponent(pathParams.branch)}/porteryaml`; + return `/api/projects/${pathParams.project_id}/gitrepos/${pathParams.git_repo_id + }/repos/${pathParams.kind}/${pathParams.owner}/${pathParams.name + }/${encodeURIComponent(pathParams.branch)}/porteryaml`; }); const parsePorterYaml = baseApi< @@ -1015,32 +1006,30 @@ const getBranchHead = baseApi< branch: string; } >("GET", (pathParams) => { - return `/api/projects/${pathParams.project_id}/gitrepos/${ - pathParams.git_repo_id - }/repos/${pathParams.kind}/${pathParams.owner}/${ - pathParams.name - }/${encodeURIComponent(pathParams.branch)}/head`; + return `/api/projects/${pathParams.project_id}/gitrepos/${pathParams.git_repo_id + }/repos/${pathParams.kind}/${pathParams.owner}/${pathParams.name + }/${encodeURIComponent(pathParams.branch)}/head`; }); const createApp = baseApi< | { - name: string; - deployment_target_id: string; - type: "github"; - git_repo_id: number; - git_branch: string; - git_repo_name: string; - porter_yaml_path: string; - } + name: string; + deployment_target_id: string; + type: "github"; + git_repo_id: number; + git_branch: string; + git_repo_name: string; + porter_yaml_path: string; + } | { - name: string; - deployment_target_id: string; - type: "docker-registry"; - image: { - repository: string; - tag: string; - }; - }, + name: string; + deployment_target_id: string; + type: "docker-registry"; + image: { + repository: string; + tag: string; + }; + }, { project_id: number; cluster_id: number; @@ -2167,6 +2156,7 @@ const registerUser = baseApi<{ last_name: string; company_name: string; referral_method?: string; + referred_by_code?: string; }>("POST", "/api/users"); const rollbackChart = baseApi< @@ -2308,11 +2298,9 @@ const getEnvGroup = baseApi< version?: number; } >("GET", (pathParams) => { - return `/api/projects/${pathParams.id}/clusters/${ - pathParams.cluster_id - }/namespaces/${pathParams.namespace}/envgroup?name=${pathParams.name}${ - pathParams.version ? "&version=" + pathParams.version : "" - }`; + return `/api/projects/${pathParams.id}/clusters/${pathParams.cluster_id + }/namespaces/${pathParams.namespace}/envgroup?name=${pathParams.name}${pathParams.version ? "&version=" + pathParams.version : "" + }`; }); const getConfigMap = baseApi< @@ -3589,7 +3577,18 @@ const deletePaymentMethod = baseApi< `/api/projects/${project_id}/billing/payment_method/${payment_method_id}` ); -const getGithubStatus = baseApi<{}, {}>("GET", ({}) => `/api/status/github`); +const getReferralDetails = baseApi< + {}, + { + project_id?: number; + } +>( + "GET", + ({ project_id }) => + `/api/projects/${project_id}/referrals/details` +); + +const getGithubStatus = baseApi<{}, {}>("GET", ({ }) => `/api/status/github`); const createSecretAndOpenGitHubPullRequest = baseApi< { @@ -3982,6 +3981,7 @@ export default { addPaymentMethod, setDefaultPaymentMethod, deletePaymentMethod, + getReferralDetails, // STATUS getGithubStatus, diff --git a/dashboard/src/shared/types.tsx b/dashboard/src/shared/types.tsx index 30a80a2ade..f623a998a0 100644 --- a/dashboard/src/shared/types.tsx +++ b/dashboard/src/shared/types.tsx @@ -289,15 +289,15 @@ export type FormElement = { export type RepoType = { FullName: string; } & ( - | { + | { Kind: "github"; GHRepoID: number; } - | { + | { Kind: "gitlab"; GitIntegrationId: number; } -); + ); export type FileType = { path: string; @@ -344,6 +344,7 @@ export type ProjectType = { user_id: number; project_id: number; }>; + referral_code: string; }; export type ChoiceType = { @@ -379,15 +380,15 @@ export type ActionConfigType = { image_repo_uri: string; dockerfile_path?: string; } & ( - | { + | { kind: "gitlab"; gitlab_integration_id: number; } - | { + | { kind: "github"; git_repo_id: number; } -); + ); export type GithubActionConfigType = ActionConfigType & { kind: "github"; diff --git a/go.mod b/go.mod index 4b720a31ba..9775cd4cb7 100644 --- a/go.mod +++ b/go.mod @@ -82,6 +82,7 @@ require ( github.com/honeycombio/otel-config-go v1.11.0 github.com/launchdarkly/go-sdk-common/v3 v3.0.1 github.com/launchdarkly/go-server-sdk/v6 v6.1.0 + github.com/lithammer/shortuuid/v4 v4.0.0 github.com/matryer/is v1.4.0 github.com/nats-io/nats.go v1.24.0 github.com/open-policy-agent/opa v0.44.0 diff --git a/go.sum b/go.sum index 68a0c415cd..cc59b871df 100644 --- a/go.sum +++ b/go.sum @@ -1249,6 +1249,8 @@ github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/liggitt/tabwriter v0.0.0-20181228230101-89fcab3d43de h1:9TO3cAIGXtEhnIaL+V+BEER86oLrvS+kWobKpbJuye0= github.com/liggitt/tabwriter v0.0.0-20181228230101-89fcab3d43de/go.mod h1:zAbeS9B/r2mtpb6U+EI2rYA5OAXxsYw6wTamcNW+zcE= github.com/linuxkit/virtsock v0.0.0-20201010232012-f8cee7dfc7a3/go.mod h1:3r6x7q95whyfWQpmGZTu3gk3v2YkMi05HEzl7Tf7YEo= +github.com/lithammer/shortuuid/v4 v4.0.0 h1:QRbbVkfgNippHOS8PXDkti4NaWeyYfcBTHtw7k08o4c= +github.com/lithammer/shortuuid/v4 v4.0.0/go.mod h1:Zs8puNcrvf2rV9rTH51ZLLcj7ZXqQI3lv67aw4KiB1Y= github.com/logrusorgru/aurora v0.0.0-20181002194514-a7b3b318ed4e/go.mod h1:7rIyQOR62GCctdiQpZ/zOJlFyk6y+94wXzv6RNZgaR4= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= @@ -1552,8 +1554,6 @@ github.com/pmezard/go-difflib v0.0.0-20151028094244-d8ed2627bdf0/go.mod h1:iKH77 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/polyfloyd/go-errorlint v0.0.0-20210722154253-910bb7978349/go.mod h1:wi9BfjxjF/bwiZ701TzmfKu6UKC357IOAtNr0Td0Lvw= -github.com/porter-dev/api-contracts v0.2.156 h1:IooB1l6tl+jiGecj2IzYsPoIJxnePaJntDpKSwJBxgc= -github.com/porter-dev/api-contracts v0.2.156/go.mod h1:VV5BzXd02ZdbWIPLVP+PX3GKawJSGQnxorVT2sUZALU= github.com/porter-dev/api-contracts v0.2.157 h1:xjC1q4/8ZUl5QLVyCkTfIiMZn+k8h0c9AO9nrCFcZ1Y= github.com/porter-dev/api-contracts v0.2.157/go.mod h1:VV5BzXd02ZdbWIPLVP+PX3GKawJSGQnxorVT2sUZALU= github.com/porter-dev/switchboard v0.0.3 h1:dBuYkiVLa5Ce7059d6qTe9a1C2XEORFEanhbtV92R+M= diff --git a/internal/billing/metronome.go b/internal/billing/metronome.go index 0e9a786332..7b2356c92e 100644 --- a/internal/billing/metronome.go +++ b/internal/billing/metronome.go @@ -16,10 +16,13 @@ import ( ) const ( - metronomeBaseUrl = "https://api.metronome.com/v1/" - defaultCollectionMethod = "charge_automatically" - defaultMaxRetries = 10 - porterStandardTrialDays = 15 + metronomeBaseUrl = "https://api.metronome.com/v1/" + defaultCollectionMethod = "charge_automatically" + defaultMaxRetries = 10 + porterStandardTrialDays = 15 + defaultRewardAmountCents = 1000 + defaultPaidAmountCents = 0 + maxReferralRewards = 10 ) // MetronomeClient is the client used to call the Metronome API @@ -28,6 +31,15 @@ type MetronomeClient struct { billableMetrics []types.BillableMetric PorterCloudPlanID uuid.UUID PorterStandardPlanID uuid.UUID + + // DefaultRewardAmountCents is the default amount in USD cents rewarded to users + // who successfully refer a new user + DefaultRewardAmountCents float64 + // DefaultPaidAmountCents is the amount paid by the user to get the credits + // grant, if set to 0 it means they are free + DefaultPaidAmountCents float64 + // MaxReferralRewards is the maximum number of referral rewards a user can receive + MaxReferralRewards int64 } // NewMetronomeClient returns a new Metronome client @@ -43,9 +55,12 @@ func NewMetronomeClient(metronomeApiKey string, porterCloudPlanID string, porter } return MetronomeClient{ - ApiKey: metronomeApiKey, - PorterCloudPlanID: porterCloudPlanUUID, - PorterStandardPlanID: porterStandardPlanUUID, + ApiKey: metronomeApiKey, + PorterCloudPlanID: porterCloudPlanUUID, + PorterStandardPlanID: porterStandardPlanUUID, + DefaultRewardAmountCents: defaultRewardAmountCents, + DefaultPaidAmountCents: defaultPaidAmountCents, + MaxReferralRewards: maxReferralRewards, }, nil } @@ -242,6 +257,47 @@ func (m MetronomeClient) ListCustomerCredits(ctx context.Context, customerID uui return response, nil } +// CreateCreditsGrant will create a new credit grant for the customer with the specified amount +func (m MetronomeClient) CreateCreditsGrant(ctx context.Context, customerID uuid.UUID, reason string, grantAmount float64, paidAmount float64, expiresAt string) (err error) { + ctx, span := telemetry.NewSpan(ctx, "create-credits-grant") + defer span.End() + + if customerID == uuid.Nil { + return telemetry.Error(ctx, span, err, "customer id empty") + } + + path := "credits/createGrant" + creditTypeID, err := m.getCreditTypeID(ctx, "USD (cents)") + if err != nil { + return telemetry.Error(ctx, span, err, "failed to get credit type id") + } + + req := types.CreateCreditsGrantRequest{ + CustomerID: customerID, + UniquenessKey: uuid.NewString(), + GrantAmount: types.GrantAmountID{ + Amount: grantAmount, + CreditTypeID: creditTypeID, + }, + PaidAmount: types.PaidAmount{ + Amount: paidAmount, + CreditTypeID: creditTypeID, + }, + Name: "Porter Credits", + Reason: reason, + ExpiresAt: expiresAt, + Priority: 1, + } + + statusCode, err := m.do(http.MethodPost, path, req, nil) + if err != nil && statusCode != http.StatusConflict { + // a conflict response indicates the grant already exists + return telemetry.Error(ctx, span, err, "failed to create credits grant") + } + + return nil +} + // ListCustomerUsage will return the aggregated usage for a customer func (m MetronomeClient) ListCustomerUsage(ctx context.Context, customerID uuid.UUID, startingOn string, endingBefore string, windowsSize string, currentPeriod bool) (usage []types.Usage, err error) { ctx, span := telemetry.NewSpan(ctx, "list-customer-usage") @@ -359,6 +415,30 @@ func (m MetronomeClient) listBillableMetricIDs(ctx context.Context, customerID u return result.Data, nil } +func (m MetronomeClient) getCreditTypeID(ctx context.Context, currencyCode string) (creditTypeID uuid.UUID, err error) { + ctx, span := telemetry.NewSpan(ctx, "get-credit-type-id") + defer span.End() + + path := "/credit-types/list" + + var result struct { + Data []types.PricingUnit `json:"data"` + } + + _, err = m.do(http.MethodGet, path, nil, &result) + if err != nil { + return creditTypeID, telemetry.Error(ctx, span, err, "failed to retrieve billable metrics from metronome") + } + + for _, pricingUnit := range result.Data { + if pricingUnit.Name == currencyCode { + return pricingUnit.ID, nil + } + } + + return creditTypeID, telemetry.Error(ctx, span, fmt.Errorf("credit type not found for currency code %s", currencyCode), "failed to find credit type") +} + func (m MetronomeClient) do(method string, path string, body interface{}, data interface{}) (statusCode int, err error) { client := http.Client{} endpoint, err := url.JoinPath(metronomeBaseUrl, path) diff --git a/internal/models/project.go b/internal/models/project.go index 058ac3e86a..4447d35c14 100644 --- a/internal/models/project.go +++ b/internal/models/project.go @@ -226,6 +226,12 @@ type Project struct { EnableReprovision bool `gorm:"default:false"` AdvancedInfraEnabled bool `gorm:"default:false"` AdvancedRbacEnabled bool `gorm:"default:false"` + + // ReferralCode is a unique code that can be shared to referr other users to Porter + ReferralCode string + + // Referrals is a list of users that have been referred by this project's code + Referrals []Referral `json:"referrals"` } // GetFeatureFlag calls launchdarkly for the specified flag @@ -332,6 +338,7 @@ func (p *Project) ToProjectType(launchDarklyClient *features.Client) types.Proje AdvancedInfraEnabled: p.GetFeatureFlag(AdvancedInfraEnabled, launchDarklyClient), SandboxEnabled: p.EnableSandbox, AdvancedRbacEnabled: p.GetFeatureFlag(AdvancedRbacEnabled, launchDarklyClient), + ReferralCode: p.ReferralCode, } } diff --git a/internal/models/referral.go b/internal/models/referral.go new file mode 100644 index 0000000000..0dc3ce1147 --- /dev/null +++ b/internal/models/referral.go @@ -0,0 +1,42 @@ +package models + +import ( + "github.com/lithammer/shortuuid/v4" + "github.com/porter-dev/porter/api/types" + "gorm.io/gorm" +) + +const ( + // ReferralStatusSignedUp is the status of a referral where the referred user has signed up + ReferralStatusSignedUp = "signed_up" + // ReferralStatusCompleted is the status of a referral where the referred user has linked a credit card + ReferralStatusCompleted = "completed" +) + +// Referral type that extends gorm.Model +type Referral struct { + gorm.Model + + // Code is the referral code that is shared with the referred user + Code string + // ProjectID is the ID of the project that was used to refer a new user + ProjectID uint + // ReferredUserID is the ID of the user who was referred + ReferredUserID uint + // Status is the status of the referral (pending, signed_up, etc.) + Status string +} + +// NewReferralCode generates a new referral code +func NewReferralCode() string { + return shortuuid.New() +} + +// ToReferralType generates an external types.Referral to be shared over REST +func (r *Referral) ToReferralType() *types.Referral { + return &types.Referral{ + ID: r.ID, + ReferredUserID: r.ReferredUserID, + Status: r.Status, + } +} diff --git a/internal/repository/gorm/migrate.go b/internal/repository/gorm/migrate.go index 5034caa71e..cab4abef8a 100644 --- a/internal/repository/gorm/migrate.go +++ b/internal/repository/gorm/migrate.go @@ -88,5 +88,6 @@ func AutoMigrate(db *gorm.DB, debug bool) error { &models.Ipam{}, &models.AppEventWebhooks{}, &models.ClusterHealthReport{}, + &models.Referral{}, ) } diff --git a/internal/repository/gorm/referrals.go b/internal/repository/gorm/referrals.go new file mode 100644 index 0000000000..5165fd17b3 --- /dev/null +++ b/internal/repository/gorm/referrals.go @@ -0,0 +1,75 @@ +package gorm + +import ( + "errors" + + "github.com/porter-dev/porter/internal/models" + "github.com/porter-dev/porter/internal/repository" + "gorm.io/gorm" +) + +// ReferralRepository uses gorm.DB for querying the database +type ReferralRepository struct { + db *gorm.DB +} + +// NewReferralRepository returns a ReferralRepository which uses +// gorm.DB for querying the database +func NewReferralRepository(db *gorm.DB) repository.ReferralRepository { + return &ReferralRepository{db} +} + +// CreateReferral creates a new referral in the database +func (repo *ReferralRepository) CreateReferral(referral *models.Referral) (*models.Referral, error) { + project := &models.Project{} + + if err := repo.db.Where("referral_code = ?", referral.Code).First(&project).Error; err != nil { + return nil, err + } + + assoc := repo.db.Model(&project).Association("Referrals") + + if assoc.Error != nil { + return nil, assoc.Error + } + + if err := assoc.Append(referral); err != nil { + return nil, err + } + + return referral, nil +} + +// CountReferralsByProjectID returns the number of referrals a user has made +func (repo *ReferralRepository) CountReferralsByProjectID(projectID uint, status string) (int64, error) { + var count int64 + + if err := repo.db.Model(&models.Referral{}).Where("project_id = ? AND status = ?", projectID, status).Count(&count).Error; err != nil { + return 0, err + } + + return count, nil +} + +// GetReferralByReferredID returns a referral by the referred user's ID +func (repo *ReferralRepository) GetReferralByReferredID(referredID uint) (*models.Referral, error) { + referral := &models.Referral{} + err := repo.db.Where("referred_user_id = ?", referredID).First(&referral).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, nil + } + + if err != nil { + return &models.Referral{}, err + } + return referral, nil +} + +// UpdateReferral updates a referral in the database +func (repo *ReferralRepository) UpdateReferral(referral *models.Referral) (*models.Referral, error) { + if err := repo.db.Save(referral).Error; err != nil { + return nil, err + } + + return referral, nil +} diff --git a/internal/repository/gorm/repository.go b/internal/repository/gorm/repository.go index 555d14d4b7..724bcdcfb1 100644 --- a/internal/repository/gorm/repository.go +++ b/internal/repository/gorm/repository.go @@ -62,6 +62,7 @@ type GormRepository struct { datastore repository.DatastoreRepository appInstance repository.AppInstanceRepository ipam repository.IpamRepository + referral repository.ReferralRepository } func (t *GormRepository) User() repository.UserRepository { @@ -293,6 +294,11 @@ func (t *GormRepository) Ipam() repository.IpamRepository { return t.ipam } +// Referral returns the ReferralRepository interface implemented by gorm +func (t *GormRepository) Referral() repository.ReferralRepository { + return t.referral +} + // NewRepository returns a Repository which persists users in memory // and accepts a parameter that can trigger read/write errors func NewRepository(db *gorm.DB, key *[32]byte, storageBackend credentials.CredentialStorage) repository.Repository { @@ -352,5 +358,6 @@ func NewRepository(db *gorm.DB, key *[32]byte, storageBackend credentials.Creden appInstance: NewAppInstanceRepository(db), ipam: NewIpamRepository(db), appEventWebhook: NewAppEventWebhookRepository(db), + referral: NewReferralRepository(db), } } diff --git a/internal/repository/referral.go b/internal/repository/referral.go new file mode 100644 index 0000000000..4b6ff73502 --- /dev/null +++ b/internal/repository/referral.go @@ -0,0 +1,13 @@ +package repository + +import ( + "github.com/porter-dev/porter/internal/models" +) + +// ReferralRepository represents the set of queries on the Referral model +type ReferralRepository interface { + CreateReferral(referral *models.Referral) (*models.Referral, error) + GetReferralByReferredID(referredID uint) (*models.Referral, error) + CountReferralsByProjectID(projectID uint, status string) (int64, error) + UpdateReferral(referral *models.Referral) (*models.Referral, error) +} diff --git a/internal/repository/repository.go b/internal/repository/repository.go index 2bb2df1f98..a803a41f3b 100644 --- a/internal/repository/repository.go +++ b/internal/repository/repository.go @@ -55,4 +55,5 @@ type Repository interface { GithubWebhook() GithubWebhookRepository Datastore() DatastoreRepository AppInstance() AppInstanceRepository + Referral() ReferralRepository } diff --git a/internal/repository/test/referrral.go b/internal/repository/test/referrral.go new file mode 100644 index 0000000000..6de00ad0db --- /dev/null +++ b/internal/repository/test/referrral.go @@ -0,0 +1,32 @@ +package test + +import ( + "errors" + + "github.com/porter-dev/porter/internal/models" + "github.com/porter-dev/porter/internal/repository" +) + +// ReferralRepository represents the set of queries on the Referral model +type ReferralRepository struct{} + +// NewAppInstanceRepository returns the test AppInstanceRepository +func NewReferralRepository() repository.ReferralRepository { + return &ReferralRepository{} +} + +func (repo *ReferralRepository) CreateReferral(referral *models.Referral) (*models.Referral, error) { + return referral, errors.New("cannot read database") +} + +func (repo *ReferralRepository) CountReferralsByProjectID(projectID uint, status string) (int64, error) { + return 0, errors.New("cannot read database") +} + +func (repo *ReferralRepository) GetReferralByReferredID(referredID uint) (*models.Referral, error) { + return &models.Referral{}, errors.New("cannot read database") +} + +func (repo *ReferralRepository) UpdateReferral(referral *models.Referral) (*models.Referral, error) { + return referral, errors.New("cannot read database") +} diff --git a/internal/repository/test/repository.go b/internal/repository/test/repository.go index 26905a0364..1927adf491 100644 --- a/internal/repository/test/repository.go +++ b/internal/repository/test/repository.go @@ -59,6 +59,7 @@ type TestRepository struct { githubWebhook repository.GithubWebhookRepository datastore repository.DatastoreRepository appInstance repository.AppInstanceRepository + referral repository.ReferralRepository } func (t *TestRepository) User() repository.UserRepository { @@ -283,6 +284,11 @@ func (t *TestRepository) AppInstance() repository.AppInstanceRepository { return t.appInstance } +// Referral returns a test Referral +func (t *TestRepository) Referral() repository.ReferralRepository { + return t.referral +} + // NewRepository returns a Repository which persists users in memory // and accepts a parameter that can trigger read/write errors func NewRepository(canQuery bool, failingMethods ...string) repository.Repository { @@ -341,5 +347,6 @@ func NewRepository(canQuery bool, failingMethods ...string) repository.Repositor githubWebhook: NewGithubWebhookRepository(), datastore: NewDatastoreRepository(), appInstance: NewAppInstanceRepository(), + referral: NewReferralRepository(), } }