1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package jwt
16
17 import (
18 "bytes"
19 "crypto"
20 "crypto/rand"
21 "crypto/rsa"
22 "crypto/sha256"
23 "encoding/base64"
24 "encoding/json"
25 "errors"
26 "fmt"
27 "strings"
28 "time"
29 )
30
// Standard values for the "alg" and "typ" members of a Header.
const (
	// HeaderAlgRSA256 is the "alg" value for RS256 (RSASSA-PKCS1-v1_5 using
	// SHA-256, RFC 7518 §3.3), the scheme EncodeJWS and VerifyJWS implement.
	HeaderAlgRSA256 = "RS256"

	// HeaderAlgES256 is the "alg" value for ES256 (ECDSA using P-256 and
	// SHA-256, RFC 7518 §3.4).
	HeaderAlgES256 = "ES256"

	// HeaderType is the "typ" value identifying a JSON Web Token.
	HeaderType = "JWT"
)
39
40
// Header is the JOSE header of a token: the first '.'-separated segment of a
// compact JWS, serialized as unpadded base64url-encoded JSON.
type Header struct {
	// Algorithm is the "alg" member, e.g. HeaderAlgRSA256.
	Algorithm string `json:"alg"`

	// Type is the "typ" member, e.g. HeaderType.
	Type string `json:"typ"`

	// KeyID is the "kid" member. It has no omitempty option, so it is always
	// serialized, even as "".
	KeyID string `json:"kid"`
}

// encode marshals h to JSON and returns it as unpadded base64url text, the
// form used for the first segment of a compact JWS.
func (h *Header) encode() (string, error) {
	payload, err := json.Marshal(h)
	if err == nil {
		return base64.RawURLEncoding.EncodeToString(payload), nil
	}
	return "", err
}
54
55
// Claims is a JWT claims set: the second '.'-separated segment of a compact
// JWS. The typed fields cover the registered claims this package uses; any
// other claims travel in AdditionalClaims.
type Claims struct {
	// Iss is the "iss" (issuer) claim.
	Iss string `json:"iss"`

	// Scope is the "scope" claim; omitted from the JSON when empty.
	Scope string `json:"scope,omitempty"`

	// Exp is the "exp" (expiration time) claim, in seconds since the Unix
	// epoch. encode fills in a default when it is zero.
	Exp int64 `json:"exp"`

	// Iat is the "iat" (issued at) claim, in seconds since the Unix epoch.
	// encode fills in a default when it is zero.
	Iat int64 `json:"iat"`

	// Aud is the "aud" (audience) claim.
	Aud string `json:"aud"`

	// Sub is the "sub" (subject) claim; omitted from the JSON when empty.
	Sub string `json:"sub,omitempty"`

	// AdditionalClaims holds claims beyond the typed fields. It is excluded
	// from the struct's own JSON encoding; encode merges it into the same
	// object, and an entry whose name matches an emitted typed claim replaces
	// that claim.
	AdditionalClaims map[string]interface{} `json:"-"`
}

// encode returns the claims set as unpadded base64url-encoded JSON, the form
// used for the second segment of a compact JWS.
//
// encode mutates c. A zero Iat is set to the current time minus ten seconds.
// A zero Exp is set to one hour after that same backdated time. It returns an
// error if Exp is earlier than Iat.
//
// AdditionalClaims are merged into the same JSON object. When an additional
// claim has the same name as a claim produced by the typed fields, the
// additional value is kept and the typed one dropped. Each claim name
// therefore appears once, as RFC 7519 §4 requires.
func (c *Claims) encode() (string, error) {
	// Backdate by ten seconds to tolerate a local clock that runs slightly
	// ahead of the token consumer's.
	now := time.Now().Add(-10 * time.Second)
	if c.Iat == 0 {
		c.Iat = now.Unix()
	}
	if c.Exp == 0 {
		c.Exp = now.Add(time.Hour).Unix()
	}
	if c.Exp < c.Iat {
		return "", fmt.Errorf("jwt: invalid Exp = %d; must be later than Iat = %d", c.Exp, c.Iat)
	}

	b, err := json.Marshal(c)
	if err != nil {
		return "", err
	}
	if len(c.AdditionalClaims) == 0 {
		return base64.RawURLEncoding.EncodeToString(b), nil
	}

	prv, err := json.Marshal(c.AdditionalClaims)
	if err != nil {
		return "", fmt.Errorf("invalid map of additional claims %v: %w", c.AdditionalClaims, err)
	}
	merged, err := mergeClaimsJSON(b, prv, c.AdditionalClaims)
	if err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(merged), nil
}

// mergeClaimsJSON combines registered (the JSON object of the typed Claims
// fields) with additional (the JSON encoding of extra) into one JSON object.
// Names present in both are resolved in favor of extra. registered may be
// modified in place.
func mergeClaimsJSON(registered, additional []byte, extra map[string]interface{}) ([]byte, error) {
	var fields map[string]json.RawMessage
	if err := json.Unmarshal(registered, &fields); err != nil {
		return nil, fmt.Errorf("invalid JSON %s: %w", registered, err)
	}

	overlap := false
	for name := range extra {
		if _, ok := fields[name]; ok {
			overlap = true
			break
		}
	}

	if overlap {
		// Rebuild a single object so every claim name is emitted exactly once.
		// json.Marshal writes map keys in sorted order.
		merged := make(map[string]interface{}, len(fields)+len(extra))
		for name, raw := range fields {
			merged[name] = raw
		}
		for name, v := range extra {
			merged[name] = v
		}
		return json.Marshal(merged)
	}

	// No overlap: splice the two objects textually. This keeps the typed
	// fields first, in declaration order, followed by the additional claims.
	if !bytes.HasSuffix(registered, []byte{'}'}) {
		return nil, fmt.Errorf("invalid JSON %s", registered)
	}
	if !bytes.HasPrefix(additional, []byte{'{'}) {
		return nil, fmt.Errorf("invalid JSON %s", additional)
	}
	registered[len(registered)-1] = ','          // turn the closing brace into a separator
	return append(registered, additional[1:]...), nil // drop the opening brace of the second object
}
112
113
114 func EncodeJWS(header *Header, c *Claims, key *rsa.PrivateKey) (string, error) {
115 head, err := header.encode()
116 if err != nil {
117 return "", err
118 }
119 claims, err := c.encode()
120 if err != nil {
121 return "", err
122 }
123 ss := fmt.Sprintf("%s.%s", head, claims)
124 h := sha256.New()
125 h.Write([]byte(ss))
126 sig, err := rsa.SignPKCS1v15(rand.Reader, key, crypto.SHA256, h.Sum(nil))
127 if err != nil {
128 return "", err
129 }
130 return fmt.Sprintf("%s.%s", ss, base64.RawURLEncoding.EncodeToString(sig)), nil
131 }
132
133
134 func DecodeJWS(payload string) (*Claims, error) {
135
136 s := strings.Split(payload, ".")
137 if len(s) < 2 {
138 return nil, errors.New("invalid token received")
139 }
140 decoded, err := base64.RawURLEncoding.DecodeString(s[1])
141 if err != nil {
142 return nil, err
143 }
144 c := &Claims{}
145 if err := json.NewDecoder(bytes.NewBuffer(decoded)).Decode(c); err != nil {
146 return nil, err
147 }
148 if err := json.NewDecoder(bytes.NewBuffer(decoded)).Decode(&c.AdditionalClaims); err != nil {
149 return nil, err
150 }
151 return c, err
152 }
153
154
155
// VerifyJWS reports whether token is a compact JWS with exactly three
// '.'-separated segments whose last segment is a valid signature by key. The
// signature must be RSASSA-PKCS1-v1_5 with SHA-256, over the first two
// segments joined by '.'. It returns nil on success.
//
// Only the signature is checked. The header's "alg" is not consulted and no
// claims (such as exp) are validated. A nil key returns an error instead of
// panicking.
func VerifyJWS(token string, key *rsa.PublicKey) error {
	parts := strings.Split(token, ".")
	if len(parts) != 3 {
		return errors.New("jwt: invalid token received, token must have 3 parts")
	}

	sig, err := base64.RawURLEncoding.DecodeString(parts[2])
	if err != nil {
		return err
	}

	// Checked after the format checks so malformed tokens keep reporting the
	// same errors as before; rsa.VerifyPKCS1v15 would panic on a nil key.
	if key == nil {
		return errors.New("jwt: nil public key")
	}

	// The signing input is the header and payload exactly as they appear in
	// the token, still base64url-encoded.
	digest := sha256.Sum256([]byte(parts[0] + "." + parts[1]))
	return rsa.VerifyPKCS1v15(key, crypto.SHA256, digest[:], sig)
}
172
View as plain text