Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Referral program #4590

Merged
merged 19 commits into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 139 additions & 0 deletions api/server/handlers/billing/credits.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package billing

import (
"net/http"
"time"

"github.com/porter-dev/porter/api/server/handlers"
"github.com/porter-dev/porter/api/server/shared"
"github.com/porter-dev/porter/api/server/shared/apierrors"
"github.com/porter-dev/porter/api/server/shared/config"
"github.com/porter-dev/porter/api/types"
"github.com/porter-dev/porter/internal/models"
"github.com/porter-dev/porter/internal/telemetry"
)

const (
// referralRewardRequirement is the number of referred users required to
// be granted a credits reward
referralRewardRequirement = 5
// defaultRewardAmountUSD is the default amount in USD rewarded to users
// who reach the reward requirement
defaultRewardAmountCents = 2000
// defaultPaidAmountUSD is the amount paid by the user to get the credits
// grant, if set to 0 it means they were free
defaultPaidAmountCents = 0
)

// ListCreditsHandler is a handler for getting available credits
type ListCreditsHandler struct {
handlers.PorterHandlerWriter
}

// NewListCreditsHandler will create a new ListCreditsHandler
func NewListCreditsHandler(
config *config.Config,
writer shared.ResultWriter,
) *ListCreditsHandler {
return &ListCreditsHandler{
PorterHandlerWriter: handlers.NewDefaultPorterHandler(config, nil, writer),
}
}

func (c *ListCreditsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx, span := telemetry.NewSpan(r.Context(), "serve-list-credits")
defer span.End()

proj, _ := ctx.Value(types.ProjectScope).(*models.Project)

if !c.Config().BillingManager.MetronomeConfigLoaded || !proj.GetFeatureFlag(models.MetronomeEnabled, c.Config().LaunchDarklyClient) {
c.WriteResult(w, r, "")

telemetry.WithAttributes(span,
telemetry.AttributeKV{Key: "metronome-config-exists", Value: c.Config().BillingManager.MetronomeConfigLoaded},
telemetry.AttributeKV{Key: "metronome-enabled", Value: proj.GetFeatureFlag(models.MetronomeEnabled, c.Config().LaunchDarklyClient)},
)
return
}

credits, err := c.Config().BillingManager.MetronomeClient.ListCustomerCredits(ctx, proj.UsageID)
if err != nil {
err := telemetry.Error(ctx, span, err, "error listing credits")
c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
return
}

telemetry.WithAttributes(span,
telemetry.AttributeKV{Key: "metronome-enabled", Value: true},
telemetry.AttributeKV{Key: "usage-id", Value: proj.UsageID},
)

c.WriteResult(w, r, credits)
}

// ClaimReferralRewardHandler is a handler for granting credits
type ClaimReferralRewardHandler struct {
handlers.PorterHandlerWriter
}

// NewClaimReferralReward will create a new GrantCreditsHandler
func NewClaimReferralReward(
config *config.Config,
writer shared.ResultWriter,
) *ClaimReferralRewardHandler {
return &ClaimReferralRewardHandler{
PorterHandlerWriter: handlers.NewDefaultPorterHandler(config, nil, writer),
}
}

func (c *ClaimReferralRewardHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx, span := telemetry.NewSpan(r.Context(), "serve-claim-credits-reward")
defer span.End()

proj, _ := ctx.Value(types.ProjectScope).(*models.Project)
user, _ := ctx.Value(types.UserScope).(*models.User)

if !c.Config().BillingManager.MetronomeConfigLoaded || !proj.GetFeatureFlag(models.MetronomeEnabled, c.Config().LaunchDarklyClient) {
c.WriteResult(w, r, "")

telemetry.WithAttributes(span,
telemetry.AttributeKV{Key: "metronome-config-exists", Value: c.Config().BillingManager.MetronomeConfigLoaded},
telemetry.AttributeKV{Key: "metronome-enabled", Value: proj.GetFeatureFlag(models.MetronomeEnabled, c.Config().LaunchDarklyClient)},
)
return
}

telemetry.WithAttributes(span,
telemetry.AttributeKV{Key: "metronome-enabled", Value: true},
telemetry.AttributeKV{Key: "usage-id", Value: proj.UsageID},
telemetry.AttributeKV{Key: "referral-code", Value: user.ReferralCode},
telemetry.AttributeKV{Key: "referral-reward-received", Value: user.ReferralRewardClaimed},
)

// Check if the user is eligible for the referral reward
referralCount, err := c.Repo().Referral().GetReferralCountByUserID(user.ID)
if err != nil {
c.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(err, http.StatusInternalServerError))
return
}

if !user.ReferralRewardClaimed && referralCount >= referralRewardRequirement {
// Metronome requires an expiration to be passed in, so we set it to 5 years which in
// practice will mean the credits will run out before expiring
expiresAt := time.Now().AddDate(5, 0, 0).Format(time.RFC3339)
err := c.Config().BillingManager.MetronomeClient.CreateCreditsGrant(ctx, proj.UsageID, defaultRewardAmountCents, defaultPaidAmountCents, expiresAt)
if err != nil {
c.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(err, http.StatusInternalServerError))
return
}

user.ReferralRewardClaimed = true
_, err = c.Repo().User().UpdateUser(user)
if err != nil {
c.HandleAPIError(w, r, apierrors.NewErrPassThroughToClient(err, http.StatusInternalServerError))
return
}
}

c.WriteResult(w, r, "")
}
46 changes: 0 additions & 46 deletions api/server/handlers/billing/plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,52 +58,6 @@ func (c *ListPlansHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c.WriteResult(w, r, plan)
}

// ListCreditsHandler is a handler for getting available credits
type ListCreditsHandler struct {
handlers.PorterHandlerWriter
}

// NewListCreditsHandler will create a new ListCreditsHandler
func NewListCreditsHandler(
config *config.Config,
writer shared.ResultWriter,
) *ListCreditsHandler {
return &ListCreditsHandler{
PorterHandlerWriter: handlers.NewDefaultPorterHandler(config, nil, writer),
}
}

func (c *ListCreditsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx, span := telemetry.NewSpan(r.Context(), "serve-list-credits")
defer span.End()

proj, _ := ctx.Value(types.ProjectScope).(*models.Project)

if !c.Config().BillingManager.MetronomeConfigLoaded || !proj.GetFeatureFlag(models.MetronomeEnabled, c.Config().LaunchDarklyClient) {
c.WriteResult(w, r, "")

telemetry.WithAttributes(span,
telemetry.AttributeKV{Key: "metronome-config-exists", Value: c.Config().BillingManager.MetronomeConfigLoaded},
telemetry.AttributeKV{Key: "metronome-enabled", Value: proj.GetFeatureFlag(models.MetronomeEnabled, c.Config().LaunchDarklyClient)},
)
return
}

credits, err := c.Config().BillingManager.MetronomeClient.ListCustomerCredits(ctx, proj.UsageID)
if err != nil {
err := telemetry.Error(ctx, span, err, "error listing credits")
c.HandleAPIError(w, r, apierrors.NewErrInternal(err))
return
}

telemetry.WithAttributes(span,
telemetry.AttributeKV{Key: "metronome-enabled", Value: true},
telemetry.AttributeKV{Key: "usage-id", Value: proj.UsageID},
)

c.WriteResult(w, r, credits)
}

// ListCustomerUsageHandler returns customer usage aggregations like CPU and RAM hours.
type ListCustomerUsageHandler struct {
handlers.PorterHandlerReadWriter
Expand Down
18 changes: 15 additions & 3 deletions api/server/handlers/user/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,17 @@ func (u *UserCreateHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {

user.Password = string(hashedPw)

// Generate referral code for user
user.ReferralCode = models.NewReferralCode()

// write the user to the db
user, err = u.Repo().User().CreateUser(user)

if err != nil {
u.HandleAPIError(w, r, apierrors.NewErrInternal(err))
return
}

err = addUserToDefaultProject(u.Config(), user)

if err != nil {
u.HandleAPIError(w, r, apierrors.NewErrInternal(err))
return
Expand All @@ -95,7 +96,19 @@ func (u *UserCreateHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// non-fatal send email verification
if !user.EmailVerified {
err = startEmailVerification(u.Config(), w, r, user)
if err != nil {
u.HandleAPIErrorNoWrite(w, r, apierrors.NewErrInternal(err))
}
}

// create referral if referred by another user
if request.ReferredBy != "" {
referral := &models.Referral{
Code: request.ReferredBy,
ReferredUserID: user.ID,
}

_, err = u.Repo().Referral().CreateReferral(referral)
if err != nil {
u.HandleAPIErrorNoWrite(w, r, apierrors.NewErrInternal(err))
}
Expand Down Expand Up @@ -146,7 +159,6 @@ func addUserToDefaultProject(config *config.Config, user *models.User) error {
Kind: types.RoleAdmin,
},
})

if err != nil {
return err
}
Expand Down
61 changes: 59 additions & 2 deletions api/server/handlers/user/create_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package user_test

import (
"encoding/json"
"net/http"
"testing"

Expand All @@ -9,6 +10,7 @@ import (
"github.com/porter-dev/porter/api/server/shared/apitest"
"github.com/porter-dev/porter/api/types"
"github.com/porter-dev/porter/internal/repository/test"
"github.com/stretchr/testify/assert"
)

func TestCreateUserSuccessful(t *testing.T) {
Expand All @@ -35,7 +37,17 @@ func TestCreateUserSuccessful(t *testing.T) {

handler.ServeHTTP(rr, req)

expUser := &types.CreateUserResponse{
// Use a struct that is the same as types.User but without the
// referral fields. This is because the referral code is randomly
// generated and is tested separately.
expUser := &struct {
ID uint `json:"id"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
FirstName string `json:"first_name"`
LastName string `json:"last_name"`
CompanyName string `json:"company_name"`
}{
ID: 1,
FirstName: "Mister",
LastName: "Porter",
Expand All @@ -44,7 +56,14 @@ func TestCreateUserSuccessful(t *testing.T) {
EmailVerified: false,
}

gotUser := &types.CreateUserResponse{}
gotUser := &struct {
ID uint `json:"id"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
FirstName string `json:"first_name"`
LastName string `json:"last_name"`
CompanyName string `json:"company_name"`
}{}

apitest.AssertResponseExpected(t, rr, expUser, gotUser)
}
Expand Down Expand Up @@ -191,3 +210,41 @@ func TestFailingCreateSessionMethod(t *testing.T) {

apitest.AssertResponseInternalServerError(t, rr)
}

func TestCreateUserReferralCode(t *testing.T) {
req, rr := apitest.GetRequestAndRecorder(
t,
string(types.HTTPVerbPost),
"/api/users",
&types.CreateUserRequest{
FirstName: "Mister",
LastName: "Porter",
CompanyName: "Porter Technologies, Inc.",
Email: "[email protected]",
Password: "somepassword",
},
)

config := apitest.LoadConfig(t)

handler := user.NewUserCreateHandler(
config,
shared.NewDefaultRequestDecoderValidator(config.Logger, config.Alerter),
shared.NewDefaultResultWriter(config.Logger, config.Alerter),
)

handler.ServeHTTP(rr, req)
gotUser := &types.CreateUserResponse{}

// apitest.AssertResponseExpected(t, rr, expUser, gotUser)
err := json.NewDecoder(rr.Body).Decode(gotUser)
if err != nil {
t.Fatal(err)
}

// This is the default lenth of a shortuuid
desiredLenth := 22
assert.NotEmpty(t, gotUser.ReferralCode, "referral code should not be empty")
assert.Len(t, gotUser.ReferralCode, desiredLenth, "referral code should be 22 characters long")
assert.Equal(t, gotUser.ReferralRewardClaimed, false, "referral reward claimed should be false for new user")
}
6 changes: 3 additions & 3 deletions api/server/handlers/user/github_callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ func (p *UserOAuthGithubCallbackHandler) ServeHTTP(w http.ResponseWriter, r *htt
// non-fatal send email verification
if !user.EmailVerified {
err = startEmailVerification(p.Config(), w, r, user)

if err != nil {
p.HandleAPIErrorNoWrite(w, r, apierrors.NewErrInternal(err))
}
Expand Down Expand Up @@ -147,14 +146,15 @@ func upsertUserFromToken(config *config.Config, tok *oauth2.Token) (*models.User
GithubUserID: githubUser.GetID(),
}

user, err = config.Repo.User().CreateUser(user)
// Generate referral code for user
user.ReferralCode = models.NewReferralCode()

user, err = config.Repo.User().CreateUser(user)
if err != nil {
return nil, err
}

err = addUserToDefaultProject(config, user)

if err != nil {
return nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions api/server/handlers/user/google_callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ func (p *UserOAuthGoogleCallbackHandler) ServeHTTP(w http.ResponseWriter, r *htt
// non-fatal send email verification
if !user.EmailVerified {
err = startEmailVerification(p.Config(), w, r, user)

if err != nil {
p.HandleAPIErrorNoWrite(w, r, apierrors.NewErrInternal(err))
}
Expand Down Expand Up @@ -133,14 +132,15 @@ func upsertGoogleUserFromToken(config *config.Config, tok *oauth2.Token) (*model
GoogleUserID: gInfo.Sub,
}

user, err = config.Repo.User().CreateUser(user)
// Generate referral code for user
user.ReferralCode = models.NewReferralCode()

user, err = config.Repo.User().CreateUser(user)
if err != nil {
return nil, err
}

err = addUserToDefaultProject(config, user)

if err != nil {
return nil, err
}
Expand Down
Loading
Loading