1
21
22
23
24
25 package jwt
26
27 import (
28 "context"
29 "crypto"
30 "crypto/ecdsa"
31 "crypto/rsa"
32 "crypto/sha256"
33 "strings"
34
35 "github.com/ory/x/errorsx"
36 "gopkg.in/square/go-jose.v2"
37
38 "github.com/pkg/errors"
39 )
40
41 type JWTStrategy interface {
42 Generate(ctx context.Context, claims MapClaims, header Mapper) (string, string, error)
43 Validate(ctx context.Context, token string) (string, error)
44 Hash(ctx context.Context, in []byte) ([]byte, error)
45 Decode(ctx context.Context, token string) (*Token, error)
46 GetSignature(ctx context.Context, token string) (string, error)
47 GetSigningMethodLength() int
48 }
49
// SHA256HashSize is the length, in bytes, of a SHA-256 digest. The RS256 and
// ES256 strategies report it from GetSigningMethodLength.
var (
	SHA256HashSize = crypto.SHA256.Size()
)
51
52
53 type RS256JWTStrategy struct {
54 PrivateKey interface{}
55 }
56
57
58 func (j *RS256JWTStrategy) Generate(ctx context.Context, claims MapClaims, header Mapper) (string, string, error) {
59 return generateToken(claims, header, jose.RS256, j.PrivateKey)
60 }
61
62
63 func (j *RS256JWTStrategy) Validate(ctx context.Context, token string) (string, error) {
64 switch t := j.PrivateKey.(type) {
65 case *rsa.PrivateKey:
66 return validateToken(token, t.PublicKey)
67 case jose.OpaqueSigner:
68 return validateToken(token, t.Public().Key)
69 default:
70 return "", errors.New("Unable to validate token. Invalid PrivateKey type")
71 }
72 }
73
74
75 func (j *RS256JWTStrategy) Decode(ctx context.Context, token string) (*Token, error) {
76 switch t := j.PrivateKey.(type) {
77 case *rsa.PrivateKey:
78 return decodeToken(token, t.PublicKey)
79 case jose.OpaqueSigner:
80 return decodeToken(token, t.Public().Key)
81 default:
82 return nil, errors.New("Unable to decode token. Invalid PrivateKey type")
83 }
84 }
85
86
87 func (j *RS256JWTStrategy) GetSignature(ctx context.Context, token string) (string, error) {
88 return getTokenSignature(token)
89 }
90
91
92 func (j *RS256JWTStrategy) Hash(ctx context.Context, in []byte) ([]byte, error) {
93 return hashSHA256(in)
94 }
95
96
97 func (j *RS256JWTStrategy) GetSigningMethodLength() int {
98 return SHA256HashSize
99 }
100
101
102 type ES256JWTStrategy struct {
103 PrivateKey interface{}
104 }
105
106
107 func (j *ES256JWTStrategy) Generate(ctx context.Context, claims MapClaims, header Mapper) (string, string, error) {
108 return generateToken(claims, header, jose.ES256, j.PrivateKey)
109 }
110
111
112 func (j *ES256JWTStrategy) Validate(ctx context.Context, token string) (string, error) {
113 switch t := j.PrivateKey.(type) {
114 case *ecdsa.PrivateKey:
115 return validateToken(token, t.PublicKey)
116 case jose.OpaqueSigner:
117 return validateToken(token, t.Public().Key)
118 default:
119 return "", errors.New("Unable to validate token. Invalid PrivateKey type")
120 }
121 }
122
123
124 func (j *ES256JWTStrategy) Decode(ctx context.Context, token string) (*Token, error) {
125 switch t := j.PrivateKey.(type) {
126 case *ecdsa.PrivateKey:
127 return decodeToken(token, t.PublicKey)
128 case jose.OpaqueSigner:
129 return decodeToken(token, t.Public().Key)
130 default:
131 return nil, errors.New("Unable to decode token. Invalid PrivateKey type")
132 }
133 }
134
135
136 func (j *ES256JWTStrategy) GetSignature(ctx context.Context, token string) (string, error) {
137 return getTokenSignature(token)
138 }
139
140
141 func (j *ES256JWTStrategy) Hash(ctx context.Context, in []byte) ([]byte, error) {
142 return hashSHA256(in)
143 }
144
145
146 func (j *ES256JWTStrategy) GetSigningMethodLength() int {
147 return SHA256HashSize
148 }
149
150 func generateToken(claims MapClaims, header Mapper, signingMethod jose.SignatureAlgorithm, privateKey interface{}) (rawToken string, sig string, err error) {
151 if header == nil || claims == nil {
152 err = errors.New("Either claims or header is nil.")
153 return
154 }
155
156 token := NewWithClaims(signingMethod, claims)
157 token.Header = assign(token.Header, header.ToMap())
158
159 rawToken, err = token.SignedString(privateKey)
160 if err != nil {
161 return
162 }
163
164 sig, err = getTokenSignature(rawToken)
165 return
166 }
167
168 func decodeToken(token string, verificationKey interface{}) (*Token, error) {
169 keyFunc := func(*Token) (interface{}, error) { return verificationKey, nil }
170 return ParseWithClaims(token, MapClaims{}, keyFunc)
171 }
172
173 func validateToken(tokenStr string, verificationKey interface{}) (string, error) {
174 _, err := decodeToken(tokenStr, verificationKey)
175 if err != nil {
176 return "", err
177 }
178 return getTokenSignature(tokenStr)
179 }
180
// getTokenSignature returns the last of the three dot-separated segments of a
// compact token. It fails unless token contains exactly two '.' separators,
// i.e. exactly three segments. Segments are neither decoded nor verified, and
// empty segments are accepted.
func getTokenSignature(token string) (string, error) {
	if strings.Count(token, ".") != 2 {
		return "", errors.New("Header, body and signature must all be set")
	}
	// With exactly two separators, everything after the last '.' is the
	// third segment.
	return token[strings.LastIndexByte(token, '.')+1:], nil
}
188
189 func hashSHA256(in []byte) ([]byte, error) {
190 hash := sha256.New()
191 _, err := hash.Write(in)
192 if err != nil {
193 return []byte{}, errorsx.WithStack(err)
194 }
195 return hash.Sum([]byte{}), nil
196 }
197
// assign copies into a every entry of b whose key is not already present in a.
// Existing entries in a always win. a is modified in place and returned.
//
// If a is nil, a new map is allocated first. The original code panicked here
// (write to a nil map) whenever b was non-empty.
func assign(a, b map[string]interface{}) map[string]interface{} {
	if a == nil {
		a = make(map[string]interface{}, len(b))
	}
	for k, v := range b {
		if _, exists := a[k]; !exists {
			a[k] = v
		}
	}
	return a
}
207
View as plain text