Skip to content

Commit

Permalink
Per project referral
Browse files Browse the repository at this point in the history
  • Loading branch information
MauAraujo committed Apr 30, 2024
1 parent 6e21c18 commit 75c2941
Show file tree
Hide file tree
Showing 25 changed files with 235 additions and 467 deletions.
55 changes: 55 additions & 0 deletions api/server/handlers/billing/create.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package billing

import (
"context"
"fmt"
"net/http"
"time"

"github.com/porter-dev/porter/api/server/handlers"
"github.com/porter-dev/porter/api/server/shared"
Expand All @@ -15,6 +17,15 @@ import (
"github.com/porter-dev/porter/internal/telemetry"
)

const (
// defaultRewardAmountCents is the default amount in USD cents rewarded to users
// who successfully refer a new user
defaultRewardAmountCents = 1000
// defaultPaidAmountCents is the amount paid by the user to get the credits
// grant, if set to 0 it means they are free
defaultPaidAmountCents = 0
)

// CreateBillingHandler is a handler for creating payment methods
type CreateBillingHandler struct {
handlers.PorterHandlerWriter
Expand All @@ -41,6 +52,7 @@ func (c *CreateBillingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
defer span.End()

proj, _ := ctx.Value(types.ProjectScope).(*models.Project)
user, _ := ctx.Value(types.UserScope).(*models.User)

clientSecret, err := c.Config().BillingManager.StripeClient.CreatePaymentMethod(ctx, proj.BillingID)
if err != nil {
Expand All @@ -54,6 +66,15 @@ func (c *CreateBillingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
telemetry.AttributeKV{Key: "customer-id", Value: proj.BillingID},
)

if proj.EnableSandbox {
// Grant a reward to the project that referred this user after linking a payment method
err = c.grantRewardIfReferral(ctx, user.ID)
if err != nil {
// Only log the error in case the reward grant fails, but don't return an error to the fe
telemetry.Error(ctx, span, err, "error granting credits reward")

Check failure on line 74 in api/server/handlers/billing/create.go

View workflow job for this annotation

GitHub Actions / Go Linter

Error return value of `telemetry.Error` is not checked (errcheck)
}
}

c.WriteResult(w, r, clientSecret)
}

Expand Down Expand Up @@ -104,3 +125,37 @@ func (c *SetDefaultBillingHandler) ServeHTTP(w http.ResponseWriter, r *http.Requ

c.WriteResult(w, r, "")
}

func (c *CreateBillingHandler) grantRewardIfReferral(ctx context.Context, referredUserID uint) (err error) {
ctx, span := telemetry.NewSpan(ctx, "grant-referral-reward")
defer span.End()

referral, err := c.Repo().Referral().GetReferralByReferredID(referredUserID)
if err != nil {
return telemetry.Error(ctx, span, err, "failed to find referral by referred id")
}

referrerProject, err := c.Repo().Project().ReadProject(referral.ProjectID)
if err != nil {
return telemetry.Error(ctx, span, err, "failed to find referrer project")
}

if referral != nil && referral.Status != models.ReferralStatusCompleted {
// Metronome requires an expiration to be passed in, so we set it to 5 years which in
// practice will mean the credits will most likely run out before expiring
expiresAt := time.Now().AddDate(5, 0, 0).Format(time.RFC3339)
reason := "Referral reward"
err := c.Config().BillingManager.MetronomeClient.CreateCreditsGrant(ctx, referrerProject.UsageID, reason, defaultRewardAmountCents, defaultPaidAmountCents, expiresAt)
if err != nil {
return telemetry.Error(ctx, span, err, "failed to grand credits reward")
}

referral.Status = models.ReferralStatusCompleted
_, err = c.Repo().Referral().UpdateReferral(referral)
if err != nil {
return telemetry.Error(ctx, span, err, "error while updating referral")
}
}

return nil
}
80 changes: 0 additions & 80 deletions api/server/handlers/billing/credits.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package billing

import (
"net/http"
"time"

"github.com/porter-dev/porter/api/server/handlers"
"github.com/porter-dev/porter/api/server/shared"
Expand All @@ -13,18 +12,6 @@ import (
"github.com/porter-dev/porter/internal/telemetry"
)

const (
// referralRewardRequirement is the number of referred users required to
// be granted a credits reward
referralRewardRequirement = 5
// defaultRewardAmountUSD is the default amount in USD rewarded to users
// who reach the reward requirement
defaultRewardAmountCents = 2000
// defaultPaidAmountUSD is the amount paid by the user to get the credits
// grant, if set to 0 it means they were free
defaultPaidAmountCents = 0
)

// ListCreditsHandler is a handler for getting available credits
type ListCreditsHandler struct {
handlers.PorterHandlerWriter
Expand Down Expand Up @@ -70,70 +57,3 @@ func (c *ListCreditsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {

c.WriteResult(w, r, credits)
}

// ClaimReferralRewardHandler is a handler for granting credits
type ClaimReferralRewardHandler struct {
handlers.PorterHandlerWriter
}

// NewClaimReferralReward will create a new GrantCreditsHandler
func NewClaimReferralReward(
config *config.Config,
writer shared.ResultWriter,
) *ClaimReferralRewardHandler {
return &ClaimReferralRewardHandler{
PorterHandlerWriter: handlers.NewDefaultPorterHandler(config, nil, writer),
}
}

func (c *ClaimReferralRewardHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx, span := telemetry.NewSpan(r.Context(), "serve-claim-credits-reward")
defer span.End()

proj, _ := ctx.Value(types.ProjectScope).(*models.Project)
user, _ := ctx.Value(types.UserScope).(*models.User)

if !c.Config().BillingManager.MetronomeConfigLoaded || !proj.GetFeatureFlag(models.MetronomeEnabled, c.Config().LaunchDarklyClient) {
c.WriteResult(w, r, "")

telemetry.WithAttributes(span,
telemetry.AttributeKV{Key: "metronome-config-exists", Value: c.Config().BillingManager.MetronomeConfigLoaded},
telemetry.AttributeKV{Key: "metronome-enabled", Value: proj.GetFeatureFlag(models.MetronomeEnabled, c.Config().LaunchDarklyClient)},
)
return
}

telemetry.WithAttributes(span,
telemetry.AttributeKV{Key: "metronome-enabled", Value: true},
telemetry.AttributeKV{Key: "usage-id", Value: proj.UsageID},
telemetry.AttributeKV{Key: "referral-code", Value: user.ReferralCode},
telemetry.AttributeKV{Key: "referral-reward-received", Value: user.ReferralRewardClaimed},
)

// Check if the user is eligible for the referral reward
referralCount, err := c.Repo().Referral().GetReferralCountByUserID(user.ID)
if err != nil {
c.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(err, http.StatusInternalServerError))
return
}

if !user.ReferralRewardClaimed && referralCount >= referralRewardRequirement {
// Metronome requires an expiration to be passed in, so we set it to 5 years which in
// practice will mean the credits will run out before expiring
expiresAt := time.Now().AddDate(5, 0, 0).Format(time.RFC3339)
err := c.Config().BillingManager.MetronomeClient.CreateCreditsGrant(ctx, proj.UsageID, defaultRewardAmountCents, defaultPaidAmountCents, expiresAt)
if err != nil {
c.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(err, http.StatusInternalServerError))
return
}

user.ReferralRewardClaimed = true
_, err = c.Repo().User().UpdateUser(user)
if err != nil {
c.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(err, http.StatusInternalServerError))
return
}
}

c.WriteResult(w, r, "")
}
3 changes: 3 additions & 0 deletions api/server/handlers/project/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ func (p *ProjectCreateHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)

if p.Config().ServerConf.EnableSandbox {
step = types.StepCleanUp

// Generate referral code for porter cloud projects
proj.ReferralCode = models.NewReferralCode()
}

// create onboarding flow set to the first step. Read in env var
Expand Down
79 changes: 79 additions & 0 deletions api/server/handlers/project/referrals.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package project

import (
"net/http"

"github.com/google/uuid"
"github.com/porter-dev/porter/api/server/handlers"
"github.com/porter-dev/porter/api/server/shared"
"github.com/porter-dev/porter/api/server/shared/apierrors"
"github.com/porter-dev/porter/api/server/shared/config"
"github.com/porter-dev/porter/api/types"
"github.com/porter-dev/porter/internal/models"
"github.com/porter-dev/porter/internal/telemetry"
)

// GetProjectReferralDetailsHandler is a handler for getting a project's referral code
type GetProjectReferralDetailsHandler struct {
handlers.PorterHandlerWriter
}

// NewGetProjectReferralDetailsHandler returns an instance of GetProjectReferralDetailsHandler
func NewGetProjectReferralDetailsHandler(
config *config.Config,
writer shared.ResultWriter,
) *GetProjectReferralDetailsHandler {
return &GetProjectReferralDetailsHandler{
PorterHandlerWriter: handlers.NewDefaultPorterHandler(config, nil, writer),
}
}

func (c *GetProjectReferralDetailsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx, span := telemetry.NewSpan(r.Context(), "serve-get-project-referral-details")
defer span.End()

proj, _ := ctx.Value(types.ProjectScope).(*models.Project)

if !c.Config().BillingManager.MetronomeConfigLoaded || !proj.GetFeatureFlag(models.MetronomeEnabled, c.Config().LaunchDarklyClient) ||
proj.UsageID == uuid.Nil || proj.EnableSandbox {
c.WriteResult(w, r, "")

telemetry.WithAttributes(span,
telemetry.AttributeKV{Key: "metronome-config-exists", Value: c.Config().BillingManager.MetronomeConfigLoaded},
telemetry.AttributeKV{Key: "metronome-enabled", Value: proj.GetFeatureFlag(models.MetronomeEnabled, c.Config().LaunchDarklyClient)},
)
return
}

if proj.ReferralCode == "" {
telemetry.WithAttributes(span,
telemetry.AttributeKV{Key: "referral-code-exists", Value: false},
)

// Generate referral code for project if not present
proj.ReferralCode = models.NewReferralCode()
_, err := c.Repo().Project().UpdateProject(proj)
if err != nil {
err := telemetry.Error(ctx, span, err, "error updating project")
c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
return
}
}

referralCount, err := c.Repo().Referral().CountReferralsByProjectID(proj.ID, models.ReferralStatusCompleted)
if err != nil {
err := telemetry.Error(ctx, span, err, "error listing referrals by project id")
c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
return
}

referralCodeResponse := struct {
Code string `json:"code"`
ReferralCount int64 `json:"referral_count"`
}{
Code: proj.ReferralCode,
ReferralCount: referralCount,
}

c.WriteResult(w, r, referralCodeResponse)
}
4 changes: 1 addition & 3 deletions api/server/handlers/user/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,6 @@ func (u *UserCreateHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {

user.Password = string(hashedPw)

// Generate referral code for user
user.ReferralCode = models.NewReferralCode()

// write the user to the db
user, err = u.Repo().User().CreateUser(user)
if err != nil {
Expand Down Expand Up @@ -106,6 +103,7 @@ func (u *UserCreateHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
referral := &models.Referral{
Code: request.ReferredBy,
ReferredUserID: user.ID,
Status: models.ReferralStatusSignedUp,
}

_, err = u.Repo().Referral().CreateReferral(referral)
Expand Down
58 changes: 2 additions & 56 deletions api/server/handlers/user/create_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package user_test

import (
"encoding/json"
"net/http"
"testing"

Expand All @@ -10,7 +9,6 @@ import (
"github.com/porter-dev/porter/api/server/shared/apitest"
"github.com/porter-dev/porter/api/types"
"github.com/porter-dev/porter/internal/repository/test"
"github.com/stretchr/testify/assert"
)

func TestCreateUserSuccessful(t *testing.T) {
Expand Down Expand Up @@ -40,14 +38,7 @@ func TestCreateUserSuccessful(t *testing.T) {
// Use a struct that is the same as types.User but without the
// referral fields. This is because the referral code is randomly
// generated and is tested separately.
expUser := &struct {
ID uint `json:"id"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
FirstName string `json:"first_name"`
LastName string `json:"last_name"`
CompanyName string `json:"company_name"`
}{
expUser := &types.CreateUserResponse{
ID: 1,
FirstName: "Mister",
LastName: "Porter",
Expand All @@ -56,14 +47,7 @@ func TestCreateUserSuccessful(t *testing.T) {
EmailVerified: false,
}

gotUser := &struct {
ID uint `json:"id"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
FirstName string `json:"first_name"`
LastName string `json:"last_name"`
CompanyName string `json:"company_name"`
}{}
gotUser := &types.CreateUserResponse{}

apitest.AssertResponseExpected(t, rr, expUser, gotUser)
}
Expand Down Expand Up @@ -210,41 +194,3 @@ func TestFailingCreateSessionMethod(t *testing.T) {

apitest.AssertResponseInternalServerError(t, rr)
}

func TestCreateUserReferralCode(t *testing.T) {
req, rr := apitest.GetRequestAndRecorder(
t,
string(types.HTTPVerbPost),
"/api/users",
&types.CreateUserRequest{
FirstName: "Mister",
LastName: "Porter",
CompanyName: "Porter Technologies, Inc.",
Email: "[email protected]",
Password: "somepassword",
},
)

config := apitest.LoadConfig(t)

handler := user.NewUserCreateHandler(
config,
shared.NewDefaultRequestDecoderValidator(config.Logger, config.Alerter),
shared.NewDefaultResultWriter(config.Logger, config.Alerter),
)

handler.ServeHTTP(rr, req)
gotUser := &types.CreateUserResponse{}

// apitest.AssertResponseExpected(t, rr, expUser, gotUser)
err := json.NewDecoder(rr.Body).Decode(gotUser)
if err != nil {
t.Fatal(err)
}

// This is the default lenth of a shortuuid
desiredLenth := 22
assert.NotEmpty(t, gotUser.ReferralCode, "referral code should not be empty")
assert.Len(t, gotUser.ReferralCode, desiredLenth, "referral code should be 22 characters long")
assert.Equal(t, gotUser.ReferralRewardClaimed, false, "referral reward claimed should be false for new user")
}
3 changes: 0 additions & 3 deletions api/server/handlers/user/github_callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,6 @@ func upsertUserFromToken(config *config.Config, tok *oauth2.Token) (*models.User
GithubUserID: githubUser.GetID(),
}

// Generate referral code for user
user.ReferralCode = models.NewReferralCode()

user, err = config.Repo.User().CreateUser(user)
if err != nil {
return nil, err
Expand Down
Loading

0 comments on commit 75c2941

Please sign in to comment.