...

Source file src/github.com/lestrrat-go/jwx/jws/message.go

Documentation: github.com/lestrrat-go/jwx/jws

     1  package jws
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  
     7  	"github.com/lestrrat-go/jwx/internal/base64"
     8  	"github.com/lestrrat-go/jwx/internal/json"
     9  	"github.com/lestrrat-go/jwx/internal/pool"
    10  	"github.com/lestrrat-go/jwx/jwk"
    11  	"github.com/pkg/errors"
    12  )
    13  
    14  type collectRawCtx struct{}
    15  
    16  func (collectRawCtx) CollectRaw() bool {
    17  	return true
    18  }
    19  
    20  func NewSignature() *Signature {
    21  	return &Signature{}
    22  }
    23  
    24  func (s *Signature) DecodeCtx() DecodeCtx {
    25  	return s.dc
    26  }
    27  
    28  func (s *Signature) SetDecodeCtx(dc DecodeCtx) {
    29  	s.dc = dc
    30  }
    31  
    32  func (s Signature) PublicHeaders() Headers {
    33  	return s.headers
    34  }
    35  
    36  func (s *Signature) SetPublicHeaders(v Headers) *Signature {
    37  	s.headers = v
    38  	return s
    39  }
    40  
    41  func (s Signature) ProtectedHeaders() Headers {
    42  	return s.protected
    43  }
    44  
    45  func (s *Signature) SetProtectedHeaders(v Headers) *Signature {
    46  	s.protected = v
    47  	return s
    48  }
    49  
    50  func (s Signature) Signature() []byte {
    51  	return s.signature
    52  }
    53  
    54  func (s *Signature) SetSignature(v []byte) *Signature {
    55  	s.signature = v
    56  	return s
    57  }
    58  
    59  type signatureUnmarshalProbe struct {
    60  	Header    Headers `json:"header,omitempty"`
    61  	Protected *string `json:"protected,omitempty"`
    62  	Signature *string `json:"signature,omitempty"`
    63  }
    64  
    65  func (s *Signature) UnmarshalJSON(data []byte) error {
    66  	var sup signatureUnmarshalProbe
    67  	sup.Header = NewHeaders()
    68  	if err := json.Unmarshal(data, &sup); err != nil {
    69  		return errors.Wrap(err, `failed to unmarshal signature into temporary struct`)
    70  	}
    71  
    72  	s.headers = sup.Header
    73  	if buf := sup.Protected; buf != nil {
    74  		src := []byte(*buf)
    75  		if !bytes.HasPrefix(src, []byte{'{'}) {
    76  			decoded, err := base64.Decode(src)
    77  			if err != nil {
    78  				return errors.Wrap(err, `failed to base64 decode protected headers`)
    79  			}
    80  			src = decoded
    81  		}
    82  
    83  		prt := NewHeaders()
    84  		//nolint:forcetypeassert
    85  		prt.(*stdHeaders).SetDecodeCtx(s.DecodeCtx())
    86  		if err := json.Unmarshal(src, prt); err != nil {
    87  			return errors.Wrap(err, `failed to unmarshal protected headers`)
    88  		}
    89  		//nolint:forcetypeassert
    90  		prt.(*stdHeaders).SetDecodeCtx(nil)
    91  		s.protected = prt
    92  	}
    93  
    94  	decoded, err := base64.DecodeString(*sup.Signature)
    95  	if err != nil {
    96  		return errors.Wrap(err, `failed to base decode signature`)
    97  	}
    98  	s.signature = decoded
    99  	return nil
   100  }
   101  
   102  // Sign populates the signature field, with a signature generated by
   103  // given the signer object and payload.
   104  //
   105  // The first return value is the raw signature in binary format.
   106  // The second return value s the full three-segment signature
   107  // (e.g. "eyXXXX.XXXXX.XXXX")
   108  func (s *Signature) Sign(payload []byte, signer Signer, key interface{}) ([]byte, []byte, error) {
   109  	ctx, cancel := context.WithCancel(context.Background())
   110  	defer cancel()
   111  
   112  	hdrs, err := mergeHeaders(ctx, s.headers, s.protected)
   113  	if err != nil {
   114  		return nil, nil, errors.Wrap(err, `failed to merge headers`)
   115  	}
   116  
   117  	if err := hdrs.Set(AlgorithmKey, signer.Algorithm()); err != nil {
   118  		return nil, nil, errors.Wrap(err, `failed to set "alg"`)
   119  	}
   120  
   121  	// If the key is a jwk.Key instance, obtain the raw key
   122  	if jwkKey, ok := key.(jwk.Key); ok {
   123  		// If we have a key ID specified by this jwk.Key, use that in the header
   124  		if kid := jwkKey.KeyID(); kid != "" {
   125  			if err := hdrs.Set(jwk.KeyIDKey, kid); err != nil {
   126  				return nil, nil, errors.Wrap(err, `set key ID from jwk.Key`)
   127  			}
   128  		}
   129  	}
   130  	hdrbuf, err := json.Marshal(hdrs)
   131  	if err != nil {
   132  		return nil, nil, errors.Wrap(err, `failed to marshal headers`)
   133  	}
   134  
   135  	buf := pool.GetBytesBuffer()
   136  	defer pool.ReleaseBytesBuffer(buf)
   137  
   138  	buf.WriteString(base64.EncodeToString(hdrbuf))
   139  	buf.WriteByte('.')
   140  
   141  	var plen int
   142  	b64 := getB64Value(hdrs)
   143  	if b64 {
   144  		encoded := base64.EncodeToString(payload)
   145  		plen = len(encoded)
   146  		buf.WriteString(encoded)
   147  	} else {
   148  		if !s.detached {
   149  			if bytes.Contains(payload, []byte{'.'}) {
   150  				return nil, nil, errors.New(`payload must not contain a "."`)
   151  			}
   152  		}
   153  		plen = len(payload)
   154  		buf.Write(payload)
   155  	}
   156  
   157  	signature, err := signer.Sign(buf.Bytes(), key)
   158  	if err != nil {
   159  		return nil, nil, errors.Wrap(err, `failed to sign payload`)
   160  	}
   161  	s.signature = signature
   162  
   163  	// Detached payload, this should be removed from the end result
   164  	if s.detached {
   165  		buf.Truncate(buf.Len() - plen)
   166  	}
   167  
   168  	buf.WriteByte('.')
   169  	buf.WriteString(base64.EncodeToString(signature))
   170  	ret := make([]byte, buf.Len())
   171  	copy(ret, buf.Bytes())
   172  
   173  	return signature, ret, nil
   174  }
   175  
   176  func NewMessage() *Message {
   177  	return &Message{}
   178  }
   179  
   180  // Clears the internal raw buffer that was accumulated during
   181  // the verify phase
   182  func (m *Message) clearRaw() {
   183  	for _, sig := range m.signatures {
   184  		if protected := sig.protected; protected != nil {
   185  			if cr, ok := protected.(*stdHeaders); ok {
   186  				cr.raw = nil
   187  			}
   188  		}
   189  	}
   190  }
   191  
   192  func (m *Message) SetDecodeCtx(dc DecodeCtx) {
   193  	m.dc = dc
   194  }
   195  
   196  func (m *Message) DecodeCtx() DecodeCtx {
   197  	return m.dc
   198  }
   199  
   200  // Payload returns the decoded payload
   201  func (m Message) Payload() []byte {
   202  	return m.payload
   203  }
   204  
   205  func (m *Message) SetPayload(v []byte) *Message {
   206  	m.payload = v
   207  	return m
   208  }
   209  
   210  func (m Message) Signatures() []*Signature {
   211  	return m.signatures
   212  }
   213  
   214  func (m *Message) AppendSignature(v *Signature) *Message {
   215  	m.signatures = append(m.signatures, v)
   216  	return m
   217  }
   218  
   219  func (m *Message) ClearSignatures() *Message {
   220  	m.signatures = nil
   221  	return m
   222  }
   223  
   224  // LookupSignature looks up a particular signature entry using
   225  // the `kid` value
   226  func (m Message) LookupSignature(kid string) []*Signature {
   227  	var sigs []*Signature
   228  	for _, sig := range m.signatures {
   229  		if hdr := sig.PublicHeaders(); hdr != nil {
   230  			hdrKeyID := hdr.KeyID()
   231  			if hdrKeyID == kid {
   232  				sigs = append(sigs, sig)
   233  				continue
   234  			}
   235  		}
   236  
   237  		if hdr := sig.ProtectedHeaders(); hdr != nil {
   238  			hdrKeyID := hdr.KeyID()
   239  			if hdrKeyID == kid {
   240  				sigs = append(sigs, sig)
   241  				continue
   242  			}
   243  		}
   244  	}
   245  	return sigs
   246  }
   247  
   248  // This struct is used to first probe for the structure of the
   249  // incoming JSON object. We then decide how to parse it
   250  // from the fields that are populated.
   251  type messageUnmarshalProbe struct {
   252  	Payload    *string           `json:"payload"`
   253  	Signatures []json.RawMessage `json:"signatures,omitempty"`
   254  	Header     Headers           `json:"header,omitempty"`
   255  	Protected  *string           `json:"protected,omitempty"`
   256  	Signature  *string           `json:"signature,omitempty"`
   257  }
   258  
   259  func (m *Message) UnmarshalJSON(buf []byte) error {
   260  	m.payload = nil
   261  	m.signatures = nil
   262  	m.b64 = true
   263  
   264  	var mup messageUnmarshalProbe
   265  	mup.Header = NewHeaders()
   266  	if err := json.Unmarshal(buf, &mup); err != nil {
   267  		return errors.Wrap(err, `failed to unmarshal into temporary structure`)
   268  	}
   269  
   270  	b64 := true
   271  	if mup.Signature == nil { // flattened signature is NOT present
   272  		if len(mup.Signatures) == 0 {
   273  			return errors.New(`required field "signatures" not present`)
   274  		}
   275  
   276  		m.signatures = make([]*Signature, 0, len(mup.Signatures))
   277  		for i, rawsig := range mup.Signatures {
   278  			var sig Signature
   279  			sig.SetDecodeCtx(m.DecodeCtx())
   280  			if err := json.Unmarshal(rawsig, &sig); err != nil {
   281  				return errors.Wrapf(err, `failed to unmarshal signature #%d`, i+1)
   282  			}
   283  			sig.SetDecodeCtx(nil)
   284  
   285  			if i == 0 {
   286  				if !getB64Value(sig.protected) {
   287  					b64 = false
   288  				}
   289  			} else {
   290  				if b64 != getB64Value(sig.protected) {
   291  					return errors.Errorf(`b64 value must be the same for all signatures`)
   292  				}
   293  			}
   294  
   295  			m.signatures = append(m.signatures, &sig)
   296  		}
   297  	} else { // .signature is present, it's a flattened structure
   298  		if len(mup.Signatures) != 0 {
   299  			return errors.New(`invalid format ("signatures" and "signature" keys cannot both be present)`)
   300  		}
   301  
   302  		var sig Signature
   303  		sig.headers = mup.Header
   304  		if src := mup.Protected; src != nil {
   305  			decoded, err := base64.DecodeString(*src)
   306  			if err != nil {
   307  				return errors.Wrap(err, `failed to base64 decode flattened protected headers`)
   308  			}
   309  			prt := NewHeaders()
   310  			//nolint:forcetypeassert
   311  			prt.(*stdHeaders).SetDecodeCtx(m.DecodeCtx())
   312  			if err := json.Unmarshal(decoded, prt); err != nil {
   313  				return errors.Wrap(err, `failed to unmarshal flattened protected headers`)
   314  			}
   315  			//nolint:forcetypeassert
   316  			prt.(*stdHeaders).SetDecodeCtx(nil)
   317  			sig.protected = prt
   318  		}
   319  
   320  		decoded, err := base64.DecodeString(*mup.Signature)
   321  		if err != nil {
   322  			return errors.Wrap(err, `failed to base64 decode flattened signature`)
   323  		}
   324  		sig.signature = decoded
   325  
   326  		m.signatures = []*Signature{&sig}
   327  		b64 = getB64Value(sig.protected)
   328  	}
   329  
   330  	if mup.Payload != nil {
   331  		if !b64 { // NOT base64 encoded
   332  			m.payload = []byte(*mup.Payload)
   333  		} else {
   334  			decoded, err := base64.DecodeString(*mup.Payload)
   335  			if err != nil {
   336  				return errors.Wrap(err, `failed to base64 decode payload`)
   337  			}
   338  			m.payload = decoded
   339  		}
   340  	}
   341  	m.b64 = b64
   342  	return nil
   343  }
   344  
   345  func (m Message) MarshalJSON() ([]byte, error) {
   346  	if len(m.signatures) == 1 {
   347  		return m.marshalFlattened()
   348  	}
   349  	return m.marshalFull()
   350  }
   351  
   352  func (m Message) marshalFlattened() ([]byte, error) {
   353  	buf := pool.GetBytesBuffer()
   354  	defer pool.ReleaseBytesBuffer(buf)
   355  
   356  	sig := m.signatures[0]
   357  
   358  	buf.WriteRune('{')
   359  	var wrote bool
   360  
   361  	if hdr := sig.headers; hdr != nil {
   362  		hdrjs, err := hdr.MarshalJSON()
   363  		if err != nil {
   364  			return nil, errors.Wrap(err, `failed to marshal "header" (flattened format)`)
   365  		}
   366  		buf.WriteString(`"header":`)
   367  		buf.Write(hdrjs)
   368  		wrote = true
   369  	}
   370  
   371  	if wrote {
   372  		buf.WriteRune(',')
   373  	}
   374  	buf.WriteString(`"payload":"`)
   375  	buf.WriteString(base64.EncodeToString(m.payload))
   376  	buf.WriteRune('"')
   377  
   378  	if protected := sig.protected; protected != nil {
   379  		protectedbuf, err := protected.MarshalJSON()
   380  		if err != nil {
   381  			return nil, errors.Wrap(err, `failed to marshal "protected" (flattened format)`)
   382  		}
   383  		buf.WriteString(`,"protected":"`)
   384  		buf.WriteString(base64.EncodeToString(protectedbuf))
   385  		buf.WriteRune('"')
   386  	}
   387  
   388  	buf.WriteString(`,"signature":"`)
   389  	buf.WriteString(base64.EncodeToString(sig.signature))
   390  	buf.WriteRune('"')
   391  	buf.WriteRune('}')
   392  
   393  	ret := make([]byte, buf.Len())
   394  	copy(ret, buf.Bytes())
   395  	return ret, nil
   396  }
   397  
   398  func (m Message) marshalFull() ([]byte, error) {
   399  	buf := pool.GetBytesBuffer()
   400  	defer pool.ReleaseBytesBuffer(buf)
   401  
   402  	buf.WriteString(`{"payload":"`)
   403  	buf.WriteString(base64.EncodeToString(m.payload))
   404  	buf.WriteString(`","signatures":[`)
   405  	for i, sig := range m.signatures {
   406  		if i > 0 {
   407  			buf.WriteRune(',')
   408  		}
   409  
   410  		buf.WriteRune('{')
   411  		var wrote bool
   412  		if hdr := sig.headers; hdr != nil {
   413  			hdrbuf, err := hdr.MarshalJSON()
   414  			if err != nil {
   415  				return nil, errors.Wrapf(err, `failed to marshal "header" for signature #%d`, i+1)
   416  			}
   417  			buf.WriteString(`"header":`)
   418  			buf.Write(hdrbuf)
   419  			wrote = true
   420  		}
   421  
   422  		if protected := sig.protected; protected != nil {
   423  			protectedbuf, err := protected.MarshalJSON()
   424  			if err != nil {
   425  				return nil, errors.Wrapf(err, `failed to marshal "protected" for signature #%d`, i+1)
   426  			}
   427  			if wrote {
   428  				buf.WriteRune(',')
   429  			}
   430  			buf.WriteString(`"protected":"`)
   431  			buf.WriteString(base64.EncodeToString(protectedbuf))
   432  			buf.WriteRune('"')
   433  			wrote = true
   434  		}
   435  
   436  		if wrote {
   437  			buf.WriteRune(',')
   438  		}
   439  		buf.WriteString(`"signature":"`)
   440  		buf.WriteString(base64.EncodeToString(sig.signature))
   441  		buf.WriteString(`"}`)
   442  	}
   443  	buf.WriteString(`]}`)
   444  
   445  	ret := make([]byte, buf.Len())
   446  	copy(ret, buf.Bytes())
   447  	return ret, nil
   448  }
   449  

View as plain text