1 package jwt
2
3 import (
4 "fmt"
5
6 "github.com/lestrrat-go/jwx/internal/json"
7 "github.com/lestrrat-go/jwx/jwa"
8 "github.com/lestrrat-go/jwx/jwe"
9 "github.com/lestrrat-go/jwx/jws"
10 "github.com/pkg/errors"
11 )
12
// SerializeCtx exposes the state of a Serializer pipeline to each
// SerializeStep while it runs.
type SerializeCtx interface {
	// Step returns the zero-based index of the step currently executing.
	// As populated by Serializer.Serialize, index 0 is the implicit JSON
	// marshaling step and user-added steps start at 1.
	Step() int
	// Nested reports whether more than one user-added step is configured
	// (for example, sign followed by encrypt).
	Nested() bool
}

// Compile-time check that serializeCtx satisfies SerializeCtx.
var _ SerializeCtx = (*serializeCtx)(nil)

// serializeCtx is the concrete SerializeCtx handed to every step by
// Serializer.Serialize.
type serializeCtx struct {
	step   int
	nested bool
}

// Step implements SerializeCtx.
func (c *serializeCtx) Step() int { return c.step }

// Nested implements SerializeCtx.
func (c *serializeCtx) Nested() bool { return c.nested }

// SerializeStep is one stage of a Serializer pipeline. It receives the
// previous stage's output and returns the input for the next stage.
type SerializeStep interface {
	Serialize(SerializeCtx, interface{}) (interface{}, error)
}
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56 type Serializer struct {
57 steps []SerializeStep
58 }
59
60
61 func NewSerializer() *Serializer {
62 return &Serializer{}
63 }
64
65
66 func (s *Serializer) Reset() *Serializer {
67 s.steps = nil
68 return s
69 }
70
71
72 func (s *Serializer) Step(step SerializeStep) *Serializer {
73 s.steps = append(s.steps, step)
74 return s
75 }
76
77 type jsonSerializer struct{}
78
79 func (jsonSerializer) Serialize(_ SerializeCtx, v interface{}) (interface{}, error) {
80 token, ok := v.(Token)
81 if !ok {
82 return nil, errors.Errorf(`invalid input: expected jwt.Token`)
83 }
84
85 buf, err := json.Marshal(token)
86 if err != nil {
87 return nil, errors.Errorf(`failed to serialize as JSON`)
88 }
89 return buf, nil
90 }
91
92 type genericHeader interface {
93 Get(string) (interface{}, bool)
94 Set(string, interface{}) error
95 }
96
97 func setTypeOrCty(ctx SerializeCtx, hdrs genericHeader) error {
98
99
100 const typKey = `typ`
101 const ctyKey = `cty`
102
103 if ctx.Step() == 1 {
104
105 if _, ok := hdrs.Get(typKey); !ok {
106 if err := hdrs.Set(typKey, `JWT`); err != nil {
107 return errors.Wrapf(err, `failed to set %s key to "JWT"`, typKey)
108 }
109 }
110 } else {
111 if ctx.Nested() {
112
113
114 if err := hdrs.Set(ctyKey, `JWT`); err != nil {
115 return errors.Wrapf(err, `failed to set %s key to "JWT"`, ctyKey)
116 }
117 }
118 }
119 return nil
120 }
121
122 type jwsSerializer struct {
123 alg jwa.SignatureAlgorithm
124 key interface{}
125 options []SignOption
126 }
127
128 func (s *jwsSerializer) Serialize(ctx SerializeCtx, v interface{}) (interface{}, error) {
129 payload, ok := v.([]byte)
130 if !ok {
131 return nil, errors.New(`expected []byte as input`)
132 }
133
134 var hdrs jws.Headers
135
136 for _, option := range s.options {
137 switch option.Ident() {
138 case identJwsHeaders{}:
139 hdrs = option.Value().(jws.Headers)
140 }
141 }
142
143 if hdrs == nil {
144 hdrs = jws.NewHeaders()
145 }
146
147 if err := setTypeOrCty(ctx, hdrs); err != nil {
148 return nil, err
149 }
150
151
152
153 if v, ok := hdrs.Get("b64"); ok {
154 if bval, bok := v.(bool); bok {
155 if !bval {
156 return nil, errors.New(`b64 cannot be false for JWTs`)
157 }
158 }
159 }
160 return jws.Sign(payload, s.alg, s.key, jws.WithHeaders(hdrs))
161 }
162
163 func (s *Serializer) Sign(alg jwa.SignatureAlgorithm, key interface{}, options ...SignOption) *Serializer {
164 return s.Step(&jwsSerializer{
165 alg: alg,
166 key: key,
167 options: options,
168 })
169 }
170
171 type jweSerializer struct {
172 keyalg jwa.KeyEncryptionAlgorithm
173 key interface{}
174 contentalg jwa.ContentEncryptionAlgorithm
175 compressalg jwa.CompressionAlgorithm
176 options []EncryptOption
177 }
178
179 func (s *jweSerializer) Serialize(ctx SerializeCtx, v interface{}) (interface{}, error) {
180 payload, ok := v.([]byte)
181 if !ok {
182 return nil, fmt.Errorf(`expected []byte as input`)
183 }
184
185 var hdrs jwe.Headers
186
187 for _, option := range s.options {
188 switch option.Ident() {
189 case identJweHeaders{}:
190 hdrs = option.Value().(jwe.Headers)
191 }
192 }
193
194 if hdrs == nil {
195 hdrs = jwe.NewHeaders()
196 }
197
198 if err := setTypeOrCty(ctx, hdrs); err != nil {
199 return nil, err
200 }
201 return jwe.Encrypt(payload, s.keyalg, s.key, s.contentalg, s.compressalg, jwe.WithProtectedHeaders(hdrs))
202 }
203
204 func (s *Serializer) Encrypt(keyalg jwa.KeyEncryptionAlgorithm, key interface{}, contentalg jwa.ContentEncryptionAlgorithm, compressalg jwa.CompressionAlgorithm, options ...EncryptOption) *Serializer {
205 return s.Step(&jweSerializer{
206 keyalg: keyalg,
207 key: key,
208 contentalg: contentalg,
209 compressalg: compressalg,
210 options: options,
211 })
212 }
213
214 func (s *Serializer) Serialize(t Token) ([]byte, error) {
215 steps := make([]SerializeStep, len(s.steps)+1)
216 steps[0] = jsonSerializer{}
217 for i, step := range s.steps {
218 steps[i+1] = step
219 }
220
221 var ctx serializeCtx
222 ctx.nested = len(s.steps) > 1
223 var payload interface{} = t
224 for i, step := range steps {
225 ctx.step = i
226 v, err := step.Serialize(&ctx, payload)
227 if err != nil {
228 return nil, errors.Wrapf(err, `failed to serialize token at step #%d`, i+1)
229 }
230 payload = v
231 }
232
233 res, ok := payload.([]byte)
234 if !ok {
235 return nil, errors.New(`invalid serialization produced`)
236 }
237
238 return res, nil
239 }
240
View as plain text