-
Notifications
You must be signed in to change notification settings - Fork 4
/
parser.go
120 lines (107 loc) · 3.51 KB
/
parser.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
package atlas_claims
import (
"context"
"errors"
"strings"
"github.com/golang-jwt/jwt/v4"
"github.com/grpc-ecosystem/go-grpc-middleware/auth"
"github.com/grpc-ecosystem/go-grpc-middleware/util/metautils"
)
const (
	// SetJwtHeader is the incoming metadata key that may carry a second,
	// "new" bearer token ("<scheme> <token>") which takes precedence over the
	// token in the standard authorization header.
	SetJwtHeader = "set-authorization"
	// JwtName is the authorization scheme expected in front of a token.
	JwtName = "bearer"
)

// errMissingField is returned when no claims can be read from the context or
// a required claim is empty.
var errMissingField = errors.New("unable to get field from token")
// UnverifiedClaimsFromContext reads the bearer tokens from the incoming
// metadata of ctx (via AuthBearersFromCtx) and parses them WITHOUT verifying
// their signatures. The boolean reports whether any token produced claims.
// Parse errors are discarded; call UnverifiedClaimFromBearers to inspect them.
func UnverifiedClaimsFromContext(ctx context.Context) (*Claims, bool) {
	current, replacement := AuthBearersFromCtx(ctx)
	claims, _ := UnverifiedClaimFromBearers([]string{current}, []string{replacement})
	if claims == nil {
		return nil, false
	}
	return claims, true
}
// AuthBearersFromCtx returns the two bearer tokens carried in the incoming
// metadata of ctx:
//   - bearer: the token from the standard authorization header, extracted by
//     grpc_auth.AuthFromMD using the JwtName scheme;
//   - newBearer: the token from the SetJwtHeader header. It is set only when
//     the header has the form "<scheme> <token>" and the scheme equals
//     JwtName case-insensitively; otherwise it is "".
func AuthBearersFromCtx(ctx context.Context) (string, string) {
	// A missing or malformed authorization header is not fatal here, so the
	// error from AuthFromMD is intentionally ignored.
	bearer, _ := grpc_auth.AuthFromMD(ctx, JwtName)

	var newBearer string
	if val := metautils.ExtractIncoming(ctx).Get(SetJwtHeader); val != "" {
		splits := strings.SplitN(val, " ", 2)
		// EqualFold compares case-insensitively without allocating two
		// lower-cased copies of the strings.
		if len(splits) == 2 && strings.EqualFold(splits[0], JwtName) {
			newBearer = splits[1]
		}
	}
	return bearer, newBearer
}
// UnverifiedClaimFromBearers parses both token lists WITHOUT verifying
// signatures and returns a single set of claims.
//
// Claims parsed from newBearer take precedence over claims parsed from bearer.
// Within each list, the selection is made by
// ParseUnverifiedClaimsFromJwtStrings. When a claim is returned, the error
// slice is nil even if some tokens failed to parse. When neither list yields a
// claim, the result is nil plus every parse error, bearer errors first.
func UnverifiedClaimFromBearers(bearer, newBearer []string) (*Claims, []error) {
	bearerClaim, bearerErrs := ParseUnverifiedClaimsFromJwtStrings(bearer)
	newBearerClaim, newBearerErrs := ParseUnverifiedClaimsFromJwtStrings(newBearer)

	// NOTE: parse errors that occur alongside a usable claim are dropped here;
	// callers that need them must call ParseUnverifiedClaimsFromJwtStrings.
	if newBearerClaim != nil {
		return newBearerClaim, nil
	}
	if bearerClaim != nil {
		return bearerClaim, nil
	}
	return nil, append(bearerErrs, newBearerErrs...)
}
// ParseUnverifiedClaimsFromJwtStrings behaves like
// ParseUnverifiedClaimsFromJwtStringsRaw but omits the raw token string. It
// returns the selected claims (nil if no token parsed) and the errors of the
// tokens that failed to parse.
func ParseUnverifiedClaimsFromJwtStrings(jwtStrings []string) (validClaim *Claims, errList []error) {
	selected, _, parseErrs := ParseUnverifiedClaimsFromJwtStringsRaw(jwtStrings)
	return selected, parseErrs
}
// ParseUnverifiedClaimsFromJwtStringsRaw parses every token in jwtStrings
// WITHOUT verifying signatures. It selects the token with the greatest
// IssuedAt, i.e. the most recently issued one. On a tie, the earliest token in
// the slice wins.
//
// It returns:
//   - the selected claims;
//   - raw, the original encoded token string those claims were parsed from;
//   - one error per token that failed to parse.
//
// If no token parses, validClaim is nil and raw is "".
func ParseUnverifiedClaimsFromJwtStringsRaw(jwtStrings []string) (validClaim *Claims, raw string, errList []error) {
	// The parser is loop-invariant, so one instance is shared by all tokens.
	// This assumes ParseUnverified does not mutate the Parser (true for
	// golang-jwt v4's zero-value Parser); re-check if the library is upgraded.
	parser := &jwt.Parser{}
	for _, jwtString := range jwtStrings {
		claims := &Claims{}
		if _, _, err := parser.ParseUnverified(jwtString, claims); err != nil {
			errList = append(errList, err)
			continue
		}
		// We use the most recent token.
		if validClaim == nil || claims.IssuedAt > validClaim.IssuedAt {
			validClaim = claims
			raw = jwtString
		}
	}
	return validClaim, raw, errList
}
// GetAccountID returns the AccountId claim of the unverified token found in
// ctx. It returns errMissingField when no claims can be parsed from ctx or the
// AccountId claim is empty.
func GetAccountID(ctx context.Context) (string, error) {
	claims, found := UnverifiedClaimsFromContext(ctx)
	if !found || claims.AccountId == "" {
		return "", errMissingField
	}
	return claims.AccountId, nil
}
// GetCompartmentID returns the CompartmentID claim of the unverified token
// found in ctx.
//
// The boolean reports only whether claims could be parsed from ctx. When it is
// true, the returned ID may still be "" if the token has no compartment ID
// claim. When it is false, the ID is "".
func GetCompartmentID(ctx context.Context) (string, bool) {
	claims, found := UnverifiedClaimsFromContext(ctx)
	if !found {
		return "", false
	}
	return claims.CompartmentID, true
}
// GetAccountAndCompartmentID returns the AccountId and CompartmentID claims of
// the unverified token found in ctx.
//
// It returns errMissingField when no claims can be parsed or AccountId is
// empty. An absent compartment ID claim is not an error; it yields "".
func GetAccountAndCompartmentID(ctx context.Context) (string, string, error) {
	claims, found := UnverifiedClaimsFromContext(ctx)
	if !found || claims.AccountId == "" {
		return "", "", errMissingField
	}
	return claims.AccountId, claims.CompartmentID, nil
}