Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: null OAuth email behavior #259

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions app/data/mock/account_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func (s *accountStore) AddOauthAccount(accountID int, provider, providerID, emai

now := time.Now()
oauthAccount := &models.OauthAccount{
Email: email,
Email: &email,
AccountID: accountID,
Provider: provider,
ProviderID: providerID,
Expand All @@ -130,7 +130,7 @@ func (s *accountStore) UpdateOauthAccount(accountID int, provider, email string)

for i, oauthAccount := range oauthAccounts {
if oauthAccount.Provider == provider {
s.oauthAccountsByID[accountID][i].Email = email
s.oauthAccountsByID[accountID][i].Email = &email
return true, nil
}
}
Expand Down
23 changes: 23 additions & 0 deletions app/data/mysql/account_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,27 @@ func TestAccountStore(t *testing.T) {
db.MustExec("TRUNCATE oauth_accounts")
tester(t, store)
}

t.Run("handle oauth email with null value", func(t *testing.T) {
account, err := store.Create("migrated-user", []byte("old"))
require.NoError(t, err)

err = store.AddOauthAccount(account.ID, "provider", "provider_id", "", "token")
require.NoError(t, err)

result, err := db.Exec("UPDATE oauth_accounts SET email = NULL WHERE account_id = ?", account.ID)
require.NoError(t, err)

rowsAffected, err := result.RowsAffected()
require.NoError(t, err)

require.Equal(t, int64(1), rowsAffected)

oAccounts, err := store.GetOauthAccounts(account.ID)
require.NoError(t, err)

require.Len(t, oAccounts, 1)
require.True(t, oAccounts[0].Email == nil)
require.Equal(t, oAccounts[0].GetEmail(), "")
})
}
23 changes: 23 additions & 0 deletions app/data/postgres/account_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,27 @@ func TestAccountStore(t *testing.T) {
db.MustExec("TRUNCATE oauth_accounts")
tester(t, store)
}

t.Run("handle oauth email with null value", func(t *testing.T) {
account, err := store.Create("migrated-user", []byte("old"))
require.NoError(t, err)

err = store.AddOauthAccount(account.ID, "provider", "provider_id", "", "token")
require.NoError(t, err)

result, err := db.Exec("UPDATE oauth_accounts SET email = NULL WHERE account_id = $1", account.ID)
require.NoError(t, err)

rowsAffected, err := result.RowsAffected()
require.NoError(t, err)

require.Equal(t, int64(1), rowsAffected)

oAccounts, err := store.GetOauthAccounts(account.ID)
require.NoError(t, err)

require.Len(t, oAccounts, 1)
require.True(t, oAccounts[0].Email == nil)
require.Equal(t, oAccounts[0].GetEmail(), "")
})
}
12 changes: 10 additions & 2 deletions app/models/oauth_account.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,20 @@ type OauthAccount struct {
AccountID int `db:"account_id"`
Provider string
ProviderID string `db:"provider_id"`
Email string `db:"email"`
Email *string `db:"email"`
AccessToken string `db:"access_token"`
CreatedAt time.Time `db:"created_at"`
UpdatedAt time.Time `db:"updated_at"`
}

func (a OauthAccount) GetEmail() string {
if a.Email != nil {
return *a.Email
}

return ""
}

func (o OauthAccount) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Provider string `json:"provider"`
Expand All @@ -24,6 +32,6 @@ func (o OauthAccount) MarshalJSON() ([]byte, error) {
}{
Provider: o.Provider,
ProviderID: o.ProviderID,
Email: o.Email,
Email: o.GetEmail(),
})
}
4 changes: 2 additions & 2 deletions app/services/account_getter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ func TestAccountGetter(t *testing.T) {
require.Equal(t, 2, len(oAccounts))
require.Equal(t, "test", oAccounts[0].Provider)
require.Equal(t, "ID1", oAccounts[0].ProviderID)
require.Equal(t, "email1", oAccounts[0].Email)
require.Equal(t, "email1", oAccounts[0].GetEmail())

require.Equal(t, "trial", oAccounts[1].Provider)
require.Equal(t, "ID2", oAccounts[1].ProviderID)
require.Equal(t, "email2", oAccounts[1].Email)
require.Equal(t, "email2", oAccounts[1].GetEmail())
})
}
2 changes: 1 addition & 1 deletion app/services/identity_reconciler.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func updateUserInfo(accountStore data.AccountStore, accountID int, providerName
continue
}

if oAccount.Email != providerUser.Email {
if oAccount.GetEmail() != providerUser.Email {
_, err = accountStore.UpdateOauthAccount(accountID, oAccount.Provider, providerUser.Email)
if err != nil {
return errors.Wrap(err, "UpdateOauthAccount")
Expand Down
4 changes: 2 additions & 2 deletions app/services/identity_reconciler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func TestIdentityReconciler(t *testing.T) {
oAccounts, err := store.GetOauthAccounts(account.ID)
assert.NoError(t, err)
assert.Equal(t, 1, len(oAccounts))
assert.Equal(t, email, oAccounts[0].Email)
assert.Equal(t, email, oAccounts[0].GetEmail())
})

t.Run("update oauth email when is outdated", func(t *testing.T) {
Expand All @@ -123,6 +123,6 @@ func TestIdentityReconciler(t *testing.T) {
oAccounts, err := store.GetOauthAccounts(account.ID)
assert.NoError(t, err)
assert.Equal(t, 1, len(oAccounts))
assert.Equal(t, email, oAccounts[0].Email)
assert.Equal(t, email, oAccounts[0].GetEmail())
})
}
2 changes: 1 addition & 1 deletion server/handlers/get_account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func assertGetAccountResponse(t *testing.T, res *http.Response, acc *models.Acco
oAccounts = append(oAccounts, map[string]interface{}{
"provider": oAcc.Provider,
"provider_account_id": oAcc.ProviderID,
"email": oAcc.Email,
"email": oAcc.GetEmail(),
})
}

Expand Down
Loading