1 package jws
2
3 import (
4 "bytes"
5 "context"
6
7 "github.com/lestrrat-go/jwx/internal/base64"
8 "github.com/lestrrat-go/jwx/internal/json"
9 "github.com/lestrrat-go/jwx/internal/pool"
10 "github.com/lestrrat-go/jwx/jwk"
11 "github.com/pkg/errors"
12 )
13
// collectRawCtx is a stateless decode context whose CollectRaw method always
// reports true. NOTE(review): presumably handed to header decoding as a
// DecodeCtx so that raw protected-header bytes are retained (compare the
// stdHeaders.raw field cleared by Message.clearRaw) — confirm against the
// DecodeCtx definition.
type collectRawCtx struct{}

// CollectRaw unconditionally returns true.
func (collectRawCtx) CollectRaw() bool { return true }
19
20 func NewSignature() *Signature {
21 return &Signature{}
22 }
23
24 func (s *Signature) DecodeCtx() DecodeCtx {
25 return s.dc
26 }
27
28 func (s *Signature) SetDecodeCtx(dc DecodeCtx) {
29 s.dc = dc
30 }
31
32 func (s Signature) PublicHeaders() Headers {
33 return s.headers
34 }
35
36 func (s *Signature) SetPublicHeaders(v Headers) *Signature {
37 s.headers = v
38 return s
39 }
40
41 func (s Signature) ProtectedHeaders() Headers {
42 return s.protected
43 }
44
45 func (s *Signature) SetProtectedHeaders(v Headers) *Signature {
46 s.protected = v
47 return s
48 }
49
50 func (s Signature) Signature() []byte {
51 return s.signature
52 }
53
54 func (s *Signature) SetSignature(v []byte) *Signature {
55 s.signature = v
56 return s
57 }
58
59 type signatureUnmarshalProbe struct {
60 Header Headers `json:"header,omitempty"`
61 Protected *string `json:"protected,omitempty"`
62 Signature *string `json:"signature,omitempty"`
63 }
64
// UnmarshalJSON decodes a single JWS signature object (one element of the
// "signatures" array) into s.
//
// The "protected" member may be either base64-encoded JSON or, when its text
// starts with '{', raw JSON. While the protected headers are decoded, the
// decode context attached to s (see SetDecodeCtx) is forwarded to them.
// The "signature" member is required: an error is returned when it is
// absent, and also when it, the protected headers, or the outer object
// cannot be decoded.
func (s *Signature) UnmarshalJSON(data []byte) error {
	var sup signatureUnmarshalProbe
	sup.Header = NewHeaders()
	if err := json.Unmarshal(data, &sup); err != nil {
		return errors.Wrap(err, `failed to unmarshal signature into temporary struct`)
	}

	// Reject a missing "signature" member up front. Dereferencing
	// sup.Signature below would otherwise panic on inputs such as `{}`.
	if sup.Signature == nil {
		return errors.New(`required field "signature" not present`)
	}

	s.headers = sup.Header
	if buf := sup.Protected; buf != nil {
		src := []byte(*buf)
		// Anything that does not already look like a JSON object is treated
		// as base64-encoded.
		if !bytes.HasPrefix(src, []byte{'{'}) {
			decoded, err := base64.Decode(src)
			if err != nil {
				return errors.Wrap(err, `failed to base64 decode protected headers`)
			}
			src = decoded
		}

		prt := NewHeaders()
		// Forward the decode context only for the duration of this decode.
		prt.(*stdHeaders).SetDecodeCtx(s.DecodeCtx())
		if err := json.Unmarshal(src, prt); err != nil {
			return errors.Wrap(err, `failed to unmarshal protected headers`)
		}
		prt.(*stdHeaders).SetDecodeCtx(nil)
		s.protected = prt
	}

	decoded, err := base64.DecodeString(*sup.Signature)
	if err != nil {
		return errors.Wrap(err, `failed to base64 decode signature`)
	}
	s.signature = decoded
	return nil
}
101
102
103
104
105
106
107
108 func (s *Signature) Sign(payload []byte, signer Signer, key interface{}) ([]byte, []byte, error) {
109 ctx, cancel := context.WithCancel(context.Background())
110 defer cancel()
111
112 hdrs, err := mergeHeaders(ctx, s.headers, s.protected)
113 if err != nil {
114 return nil, nil, errors.Wrap(err, `failed to merge headers`)
115 }
116
117 if err := hdrs.Set(AlgorithmKey, signer.Algorithm()); err != nil {
118 return nil, nil, errors.Wrap(err, `failed to set "alg"`)
119 }
120
121
122 if jwkKey, ok := key.(jwk.Key); ok {
123
124 if kid := jwkKey.KeyID(); kid != "" {
125 if err := hdrs.Set(jwk.KeyIDKey, kid); err != nil {
126 return nil, nil, errors.Wrap(err, `set key ID from jwk.Key`)
127 }
128 }
129 }
130 hdrbuf, err := json.Marshal(hdrs)
131 if err != nil {
132 return nil, nil, errors.Wrap(err, `failed to marshal headers`)
133 }
134
135 buf := pool.GetBytesBuffer()
136 defer pool.ReleaseBytesBuffer(buf)
137
138 buf.WriteString(base64.EncodeToString(hdrbuf))
139 buf.WriteByte('.')
140
141 var plen int
142 b64 := getB64Value(hdrs)
143 if b64 {
144 encoded := base64.EncodeToString(payload)
145 plen = len(encoded)
146 buf.WriteString(encoded)
147 } else {
148 if !s.detached {
149 if bytes.Contains(payload, []byte{'.'}) {
150 return nil, nil, errors.New(`payload must not contain a "."`)
151 }
152 }
153 plen = len(payload)
154 buf.Write(payload)
155 }
156
157 signature, err := signer.Sign(buf.Bytes(), key)
158 if err != nil {
159 return nil, nil, errors.Wrap(err, `failed to sign payload`)
160 }
161 s.signature = signature
162
163
164 if s.detached {
165 buf.Truncate(buf.Len() - plen)
166 }
167
168 buf.WriteByte('.')
169 buf.WriteString(base64.EncodeToString(signature))
170 ret := make([]byte, buf.Len())
171 copy(ret, buf.Bytes())
172
173 return signature, ret, nil
174 }
175
176 func NewMessage() *Message {
177 return &Message{}
178 }
179
180
181
182 func (m *Message) clearRaw() {
183 for _, sig := range m.signatures {
184 if protected := sig.protected; protected != nil {
185 if cr, ok := protected.(*stdHeaders); ok {
186 cr.raw = nil
187 }
188 }
189 }
190 }
191
192 func (m *Message) SetDecodeCtx(dc DecodeCtx) {
193 m.dc = dc
194 }
195
196 func (m *Message) DecodeCtx() DecodeCtx {
197 return m.dc
198 }
199
200
201 func (m Message) Payload() []byte {
202 return m.payload
203 }
204
205 func (m *Message) SetPayload(v []byte) *Message {
206 m.payload = v
207 return m
208 }
209
210 func (m Message) Signatures() []*Signature {
211 return m.signatures
212 }
213
214 func (m *Message) AppendSignature(v *Signature) *Message {
215 m.signatures = append(m.signatures, v)
216 return m
217 }
218
219 func (m *Message) ClearSignatures() *Message {
220 m.signatures = nil
221 return m
222 }
223
224
225
226 func (m Message) LookupSignature(kid string) []*Signature {
227 var sigs []*Signature
228 for _, sig := range m.signatures {
229 if hdr := sig.PublicHeaders(); hdr != nil {
230 hdrKeyID := hdr.KeyID()
231 if hdrKeyID == kid {
232 sigs = append(sigs, sig)
233 continue
234 }
235 }
236
237 if hdr := sig.ProtectedHeaders(); hdr != nil {
238 hdrKeyID := hdr.KeyID()
239 if hdrKeyID == kid {
240 sigs = append(sigs, sig)
241 continue
242 }
243 }
244 }
245 return sigs
246 }
247
248
249
250
251 type messageUnmarshalProbe struct {
252 Payload *string `json:"payload"`
253 Signatures []json.RawMessage `json:"signatures,omitempty"`
254 Header Headers `json:"header,omitempty"`
255 Protected *string `json:"protected,omitempty"`
256 Signature *string `json:"signature,omitempty"`
257 }
258
// UnmarshalJSON parses a JWS JSON serialization into m. It accepts either the
// general form (a "signatures" array) or the flattened form (a top-level
// "signature" member), but not both.
//
// Decoding rules:
//   - Every signature must agree on the b64 setting (as reported by
//     getB64Value on its protected headers).
//   - The payload is base64-decoded unless that setting is false, in which
//     case it is taken verbatim.
//   - The message's decode context is forwarded to each signature (general
//     form) or to the protected headers (flattened form) during decoding.
//
// m is reset first. On error it may be left partially populated.
func (m *Message) UnmarshalJSON(buf []byte) error {
	m.payload = nil
	m.signatures = nil
	m.b64 = true

	var probe messageUnmarshalProbe
	probe.Header = NewHeaders()
	if err := json.Unmarshal(buf, &probe); err != nil {
		return errors.Wrap(err, `failed to unmarshal into temporary structure`)
	}

	encoded := true
	switch {
	case probe.Signature != nil:
		// Flattened serialization.
		if len(probe.Signatures) != 0 {
			return errors.New(`invalid format ("signatures" and "signature" keys cannot both be present)`)
		}

		sig := &Signature{headers: probe.Header}
		if src := probe.Protected; src != nil {
			rawHdr, err := base64.DecodeString(*src)
			if err != nil {
				return errors.Wrap(err, `failed to base64 decode flattened protected headers`)
			}
			prt := NewHeaders()
			prt.(*stdHeaders).SetDecodeCtx(m.DecodeCtx())
			if err := json.Unmarshal(rawHdr, prt); err != nil {
				return errors.Wrap(err, `failed to unmarshal flattened protected headers`)
			}
			prt.(*stdHeaders).SetDecodeCtx(nil)
			sig.protected = prt
		}

		rawSig, err := base64.DecodeString(*probe.Signature)
		if err != nil {
			return errors.Wrap(err, `failed to base64 decode flattened signature`)
		}
		sig.signature = rawSig

		m.signatures = []*Signature{sig}
		encoded = getB64Value(sig.protected)

	case len(probe.Signatures) == 0:
		return errors.New(`required field "signatures" not present`)

	default:
		// General serialization.
		m.signatures = make([]*Signature, 0, len(probe.Signatures))
		for idx, raw := range probe.Signatures {
			sig := &Signature{}
			sig.SetDecodeCtx(m.DecodeCtx())
			if err := json.Unmarshal(raw, sig); err != nil {
				return errors.Wrapf(err, `failed to unmarshal signature #%d`, idx+1)
			}
			sig.SetDecodeCtx(nil)

			// The first signature fixes the b64 setting; the rest must match.
			sigB64 := getB64Value(sig.protected)
			if idx == 0 {
				encoded = sigB64
			} else if sigB64 != encoded {
				return errors.Errorf(`b64 value must be the same for all signatures`)
			}

			m.signatures = append(m.signatures, sig)
		}
	}

	if p := probe.Payload; p != nil {
		if encoded {
			decoded, err := base64.DecodeString(*p)
			if err != nil {
				return errors.Wrap(err, `failed to base64 decode payload`)
			}
			m.payload = decoded
		} else {
			m.payload = []byte(*p)
		}
	}
	m.b64 = encoded
	return nil
}
344
345 func (m Message) MarshalJSON() ([]byte, error) {
346 if len(m.signatures) == 1 {
347 return m.marshalFlattened()
348 }
349 return m.marshalFull()
350 }
351
352 func (m Message) marshalFlattened() ([]byte, error) {
353 buf := pool.GetBytesBuffer()
354 defer pool.ReleaseBytesBuffer(buf)
355
356 sig := m.signatures[0]
357
358 buf.WriteRune('{')
359 var wrote bool
360
361 if hdr := sig.headers; hdr != nil {
362 hdrjs, err := hdr.MarshalJSON()
363 if err != nil {
364 return nil, errors.Wrap(err, `failed to marshal "header" (flattened format)`)
365 }
366 buf.WriteString(`"header":`)
367 buf.Write(hdrjs)
368 wrote = true
369 }
370
371 if wrote {
372 buf.WriteRune(',')
373 }
374 buf.WriteString(`"payload":"`)
375 buf.WriteString(base64.EncodeToString(m.payload))
376 buf.WriteRune('"')
377
378 if protected := sig.protected; protected != nil {
379 protectedbuf, err := protected.MarshalJSON()
380 if err != nil {
381 return nil, errors.Wrap(err, `failed to marshal "protected" (flattened format)`)
382 }
383 buf.WriteString(`,"protected":"`)
384 buf.WriteString(base64.EncodeToString(protectedbuf))
385 buf.WriteRune('"')
386 }
387
388 buf.WriteString(`,"signature":"`)
389 buf.WriteString(base64.EncodeToString(sig.signature))
390 buf.WriteRune('"')
391 buf.WriteRune('}')
392
393 ret := make([]byte, buf.Len())
394 copy(ret, buf.Bytes())
395 return ret, nil
396 }
397
398 func (m Message) marshalFull() ([]byte, error) {
399 buf := pool.GetBytesBuffer()
400 defer pool.ReleaseBytesBuffer(buf)
401
402 buf.WriteString(`{"payload":"`)
403 buf.WriteString(base64.EncodeToString(m.payload))
404 buf.WriteString(`","signatures":[`)
405 for i, sig := range m.signatures {
406 if i > 0 {
407 buf.WriteRune(',')
408 }
409
410 buf.WriteRune('{')
411 var wrote bool
412 if hdr := sig.headers; hdr != nil {
413 hdrbuf, err := hdr.MarshalJSON()
414 if err != nil {
415 return nil, errors.Wrapf(err, `failed to marshal "header" for signature #%d`, i+1)
416 }
417 buf.WriteString(`"header":`)
418 buf.Write(hdrbuf)
419 wrote = true
420 }
421
422 if protected := sig.protected; protected != nil {
423 protectedbuf, err := protected.MarshalJSON()
424 if err != nil {
425 return nil, errors.Wrapf(err, `failed to marshal "protected" for signature #%d`, i+1)
426 }
427 if wrote {
428 buf.WriteRune(',')
429 }
430 buf.WriteString(`"protected":"`)
431 buf.WriteString(base64.EncodeToString(protectedbuf))
432 buf.WriteRune('"')
433 wrote = true
434 }
435
436 if wrote {
437 buf.WriteRune(',')
438 }
439 buf.WriteString(`"signature":"`)
440 buf.WriteString(base64.EncodeToString(sig.signature))
441 buf.WriteString(`"}`)
442 }
443 buf.WriteString(`]}`)
444
445 ret := make([]byte, buf.Len())
446 copy(ret, buf.Bytes())
447 return ret, nil
448 }
449
View as plain text