...

Source file src/github.com/lestrrat-go/jwx/jws/jws.go

Documentation: github.com/lestrrat-go/jwx/jws

     1  //go:generate ./gen.sh
     2  
     3  // Package jws implements the digital signature on JSON based data
     4  // structures as described in https://tools.ietf.org/html/rfc7515
     5  //
     6  // If you do not care about the details, the only things that you
     7  // would need to use are the following functions:
     8  //
     9  //     jws.Sign(payload, algorithm, key)
    10  //     jws.Verify(encodedjws, algorithm, key)
    11  //
    12  // To sign, simply use `jws.Sign`. `payload` is a []byte buffer that
    13  // contains whatever data you want to sign. `alg` is one of the
    14  // jwa.SignatureAlgorithm constants from package jwa. For RSA and
    15  // ECDSA family of algorithms, you will need to prepare a private key.
    16  // For HMAC family, you just need a []byte value. The `jws.Sign`
    17  // function will return the encoded JWS message on success.
    18  //
    19  // To verify, use `jws.Verify`. It will parse the `encodedjws` buffer
    20  // and verify the result using `algorithm` and `key`. Upon successful
    21  // verification, the original payload is returned, so you can work on it.
    22  package jws
    23  
    24  import (
    25  	"bufio"
    26  	"bytes"
    27  	"context"
    28  	"crypto/ecdsa"
    29  	"crypto/ed25519"
    30  	"crypto/rsa"
    31  	"fmt"
    32  	"io"
    33  	"io/ioutil"
    34  	"net/http"
    35  	"net/url"
    36  	"reflect"
    37  	"strings"
    38  	"sync"
    39  	"unicode"
    40  	"unicode/utf8"
    41  
    42  	"github.com/lestrrat-go/backoff/v2"
    43  	"github.com/lestrrat-go/jwx/internal/base64"
    44  	"github.com/lestrrat-go/jwx/internal/json"
    45  	"github.com/lestrrat-go/jwx/internal/pool"
    46  	"github.com/lestrrat-go/jwx/jwa"
    47  	"github.com/lestrrat-go/jwx/jwk"
    48  	"github.com/lestrrat-go/jwx/x25519"
    49  	"github.com/pkg/errors"
    50  )
    51  
// registry is the package-level registry created by json.NewRegistry().
// RegisterCustomField adds entries to it.
var registry = json.NewRegistry()
    53  
// payloadSigner bundles a Signer with the key it signs with and the
// protected/public headers that accompany the resulting signature.
// SignMulti collects values of this type from its options.
type payloadSigner struct {
	// signer performs the actual signing and reports the algorithm.
	signer    Signer
	// key is handed to signer.Sign as the key argument.
	key       interface{}
	// protected is the value returned by ProtectedHeader.
	protected Headers
	// public is the value returned by PublicHeader.
	public    Headers
}
    60  
// Sign signs payload by delegating to the wrapped Signer, passing the
// key stored in s.
func (s *payloadSigner) Sign(payload []byte) ([]byte, error) {
	return s.signer.Sign(payload, s.key)
}
    64  
// Algorithm returns the signature algorithm reported by the wrapped Signer.
func (s *payloadSigner) Algorithm() jwa.SignatureAlgorithm {
	return s.signer.Algorithm()
}
    68  
// ProtectedHeader returns the protected headers stored in s. The value
// may be nil; SignMulti substitutes a fresh Headers in that case.
func (s *payloadSigner) ProtectedHeader() Headers {
	return s.protected
}
    72  
// PublicHeader returns the public (unprotected) headers stored in s.
func (s *payloadSigner) PublicHeader() Headers {
	return s.public
}
    76  
// signers caches one Signer per signature algorithm. Sign fills it lazily
// via NewSigner the first time an algorithm is requested.
var signers = make(map[jwa.SignatureAlgorithm]Signer)

// muSigner guards access to signers.
var muSigner = &sync.Mutex{}
    79  
    80  // Sign generates a signature for the given payload, and serializes
    81  // it in compact serialization format. In this format you may NOT use
    82  // multiple signers.
    83  //
    84  // The `alg` parameter is the identifier for the signature algorithm
    85  // that should be used.
    86  //
    87  // For the `key` parameter, any of the following is accepted:
    88  // * A "raw" key (e.g. rsa.PrivateKey, ecdsa.PrivateKey, etc)
    89  // * A crypto.Signer
    90  // * A jwk.Key
    91  //
    92  // A `crypto.Signer` is used when the private part of a key is
    93  // kept in an inaccessible location, such as hardware.
    94  // `crypto.Signer` is currently supported for RSA, ECDSA, and EdDSA
    95  // family of algorithms.
    96  //
    97  // If the key is a jwk.Key and the key contains a key ID (`kid` field),
    98  // then it is added to the protected header generated by the signature
    99  //
   100  // The algorithm specified in the `alg` parameter must be able to support
   101  // the type of key you provided, otherwise an error is returned.
   102  //
   103  // If you would like to pass custom headers, use the WithHeaders option.
   104  //
   105  // If the headers contain "b64" field, then the boolean value for the field
   106  // is respected when creating the compact serialization form. That is,
   107  // if you specify a header with `{"b64": false}`, then the payload is
   108  // not base64 encoded.
   109  //
   110  // If you want to use a detached payload, use `jws.WithDetachedPayload()` as
   111  // one of the options. When you use this option, you must always set the
   112  // first parameter (`payload`) to `nil`, or the function will return an error
   113  func Sign(payload []byte, alg jwa.SignatureAlgorithm, key interface{}, options ...SignOption) ([]byte, error) {
   114  	var hdrs Headers
   115  	var detached bool
   116  	for _, o := range options {
   117  		//nolint:forcetypeassert
   118  		switch o.Ident() {
   119  		case identHeaders{}:
   120  			hdrs = o.Value().(Headers)
   121  		case identDetachedPayload{}:
   122  			detached = true
   123  			if payload != nil {
   124  				return nil, errors.New(`jws.Sign: payload must be nil when jws.WithDetachedPayload() is specified`)
   125  			}
   126  			payload = o.Value().([]byte)
   127  		}
   128  	}
   129  
   130  	muSigner.Lock()
   131  	signer, ok := signers[alg]
   132  	if !ok {
   133  		v, err := NewSigner(alg)
   134  		if err != nil {
   135  			muSigner.Unlock()
   136  			return nil, errors.Wrap(err, `failed to create signer`)
   137  		}
   138  		signers[alg] = v
   139  		signer = v
   140  	}
   141  	muSigner.Unlock()
   142  
   143  	// XXX This is cheating. Ideally `detached` should be passed as a parameter
   144  	// but since this is an exported method, we can't change this without bumping
   145  	// major versions.... But we don't want to do that now, so we will cheat by
   146  	// making it part of the object
   147  	sig := &Signature{
   148  		protected: hdrs,
   149  		detached:  detached,
   150  	}
   151  	_, signature, err := sig.Sign(payload, signer, key)
   152  	if err != nil {
   153  		return nil, errors.Wrap(err, `failed sign payload`)
   154  	}
   155  
   156  	return signature, nil
   157  }
   158  
   159  // SignMulti accepts multiple signers via the options parameter,
   160  // and creates a JWS in JSON serialization format that contains
   161  // signatures from applying aforementioned signers.
   162  //
   163  // Use `jws.WithSigner(...)` to specify values how to generate
   164  // each signature in the `"signatures": [ ... ]` field.
   165  func SignMulti(payload []byte, options ...Option) ([]byte, error) {
   166  	var signers []*payloadSigner
   167  	for _, o := range options {
   168  		//nolint:forcetypeassert
   169  		switch o.Ident() {
   170  		case identPayloadSigner{}:
   171  			signers = append(signers, o.Value().(*payloadSigner))
   172  		}
   173  	}
   174  
   175  	if len(signers) == 0 {
   176  		return nil, errors.New(`no signers provided`)
   177  	}
   178  
   179  	var result Message
   180  
   181  	result.payload = payload
   182  
   183  	result.signatures = make([]*Signature, 0, len(signers))
   184  	for i, signer := range signers {
   185  		protected := signer.ProtectedHeader()
   186  		if protected == nil {
   187  			protected = NewHeaders()
   188  		}
   189  
   190  		if err := protected.Set(AlgorithmKey, signer.Algorithm()); err != nil {
   191  			return nil, errors.Wrap(err, `failed to set "alg" header`)
   192  		}
   193  
   194  		if key, ok := signer.key.(jwk.Key); ok {
   195  			if kid := key.KeyID(); kid != "" {
   196  				if err := protected.Set(KeyIDKey, kid); err != nil {
   197  					return nil, errors.Wrap(err, `failed to set "kid" header`)
   198  				}
   199  			}
   200  		}
   201  		sig := &Signature{
   202  			headers:   signer.PublicHeader(),
   203  			protected: protected,
   204  		}
   205  		_, _, err := sig.Sign(payload, signer.signer, signer.key)
   206  		if err != nil {
   207  			return nil, errors.Wrapf(err, `failed to generate signature for signer #%d (alg=%s)`, i, signer.Algorithm())
   208  		}
   209  
   210  		result.signatures = append(result.signatures, sig)
   211  	}
   212  
   213  	return json.Marshal(result)
   214  }
   215  
// verifyCtx carries the parameters and state of a single verification
// run started by Verify or VerifyAuto.
type verifyCtx struct {
	// dst, when non-nil, receives the parsed Message after a
	// successful verification.
	dst             *Message
	// detachedPayload is the payload supplied out-of-band by the caller.
	detachedPayload []byte
	// alg is the algorithm passed to Verify; it is used to create the
	// Verifier when JKU processing is disabled.
	alg             jwa.SignatureAlgorithm
	// key is the verification key. verifyJKU overwrites it with the key
	// found in the fetched JWK set.
	key             interface{}
	// useJKU enables key discovery through the "jku" header; VerifyAuto
	// sets it to true.
	useJKU          bool
	// jwksFetcher retrieves the JWK set referenced by "jku".
	jwksFetcher     JWKSetFetcher
	// This is only used to differentiate compact/JSON serialization
	// because certain features are enabled/disabled in each
	isJSON bool
}
   227  
// allowNoneWhitelist is a whitelist that rejects every URL: its function
// returns false for any input. VerifyAuto installs it as the default
// fetch whitelist when the caller does not supply a JWKSetFetcher.
var allowNoneWhitelist = jwk.WhitelistFunc(func(string) bool {
	return false
})
   231  
   232  // VerifyAuto is a special case of Verify(), where verification is done
   233  // using verifications parameters that can be obtained using the information
   234  // that is carried within the JWS message itself.
   235  //
   236  // Currently it only supports verification via `jku` which will be fetched
   237  // using the object specified in `jws.JWKSetFetcher`. Note that URLs in `jku` can
   238  // only have https scheme.
   239  //
   240  // Using this function will result in your program accessing remote resources via https,
   241  // and therefore extreme caution should be taken which urls can be accessed.
   242  //
   243  // Without specifying extra arguments, the default `jws.JWKSetFetcher` will be
   244  // configured with a whitelist that rejects *ALL URLSs*. This is to
   245  // protect users from unintentionally allowing their projects to
   246  // make unwanted requests. Therefore you must explicitly provide an
   247  // instance of `jwk.Whitelist` that does what you want.
   248  //
   249  // If you want open access to any URLs in the `jku`, you can do this by
   250  // using `jwk.InsecureWhitelist` as the whitelist, but this should be avoided in
   251  // most cases, especially if the payload comes from outside of a controlled
   252  // environment.
   253  //
   254  // It is also advised that you consider using some sort of backoff via `jws.WithFetchBackoff`
   255  //
   256  // Alternatively, you can provide your own `jws.JWKSetFetcher`. In this case
   257  // there is no way for the framework to force you to set a whitelist, so the
   258  // default behavior is to allow any URLs. You are responsible for providing
   259  // your own safety measures.
   260  func VerifyAuto(buf []byte, options ...VerifyOption) ([]byte, error) {
   261  	var ctx verifyCtx
   262  	// enable JKU processing
   263  	ctx.useJKU = true
   264  
   265  	var fetchOptions []jwk.FetchOption
   266  
   267  	//nolint:forcetypeassert
   268  	for _, option := range options {
   269  		switch option.Ident() {
   270  		case identMessage{}:
   271  			ctx.dst = option.Value().(*Message)
   272  		case identDetachedPayload{}:
   273  			ctx.detachedPayload = option.Value().([]byte)
   274  		case identJWKSetFetcher{}:
   275  			ctx.jwksFetcher = option.Value().(JWKSetFetcher)
   276  		case identFetchWhitelist{}:
   277  			fetchOptions = append(fetchOptions, jwk.WithFetchWhitelist(option.Value().(jwk.Whitelist)))
   278  		case identFetchBackoff{}:
   279  			fetchOptions = append(fetchOptions, jwk.WithFetchBackoff(option.Value().(backoff.Policy)))
   280  		case identHTTPClient{}:
   281  			fetchOptions = append(fetchOptions, jwk.WithHTTPClient(option.Value().(*http.Client)))
   282  		}
   283  	}
   284  
   285  	// We shove the default Whitelist in the front of the option list.
   286  	// If the user provided one, it will overwrite our default value
   287  	if ctx.jwksFetcher == nil {
   288  		fetchOptions = append([]jwk.FetchOption{jwk.WithFetchWhitelist(allowNoneWhitelist)}, fetchOptions...)
   289  		ctx.jwksFetcher = NewJWKSetFetcher(fetchOptions...)
   290  	}
   291  
   292  	return ctx.verify(buf)
   293  }
   294  
   295  // Verify checks if the given JWS message is verifiable using `alg` and `key`.
   296  // `key` may be a "raw" key (e.g. rsa.PublicKey) or a jwk.Key
   297  //
   298  // If the verification is successful, `err` is nil, and the content of the
   299  // payload that was signed is returned. If you need more fine-grained
   300  // control of the verification process, manually generate a
   301  // `Verifier` in `verify` subpackage, and call `Verify` method on it.
   302  // If you need to access signatures and JOSE headers in a JWS message,
   303  // use `Parse` function to get `Message` object.
   304  func Verify(buf []byte, alg jwa.SignatureAlgorithm, key interface{}, options ...VerifyOption) ([]byte, error) {
   305  	var ctx verifyCtx
   306  	ctx.alg = alg
   307  	ctx.key = key
   308  	//nolint:forcetypeassert
   309  	for _, option := range options {
   310  		switch option.Ident() {
   311  		case identMessage{}:
   312  			ctx.dst = option.Value().(*Message)
   313  		case identDetachedPayload{}:
   314  			ctx.detachedPayload = option.Value().([]byte)
   315  		default:
   316  			return nil, errors.Errorf(`invalid jws.VerifyOption %q passed`, `With`+strings.TrimPrefix(fmt.Sprintf(`%T`, option.Ident()), `jws.ident`))
   317  		}
   318  	}
   319  
   320  	return ctx.verify(buf)
   321  }
   322  
   323  func (ctx *verifyCtx) verify(buf []byte) ([]byte, error) {
   324  	buf = bytes.TrimSpace(buf)
   325  	if len(buf) == 0 {
   326  		return nil, errors.New(`attempt to verify empty buffer`)
   327  	}
   328  
   329  	if buf[0] == '{' {
   330  		return ctx.verifyJSON(buf)
   331  	}
   332  	return ctx.verifyCompact(buf)
   333  }
   334  
   335  // VerifySet uses keys store in a jwk.Set to verify the payload in `buf`.
   336  //
   337  // In order for `VerifySet()` to use a key in the given set, the
   338  // `jwk.Key` object must have a valid "alg" field, and it also must
   339  // have either an empty value or the value "sig" in the "use" field.
   340  //
   341  // Furthermore if the JWS signature asks for a spefici "kid", the
   342  // `jwk.Key` must have the same "kid" as the signature.
   343  func VerifySet(buf []byte, set jwk.Set) ([]byte, error) {
   344  	n := set.Len()
   345  	for i := 0; i < n; i++ {
   346  		key, ok := set.Get(i)
   347  		if !ok {
   348  			continue
   349  		}
   350  		if key.Algorithm() == "" { // algorithm is not
   351  			continue
   352  		}
   353  
   354  		if usage := key.KeyUsage(); usage != "" && usage != jwk.ForSignature.String() {
   355  			continue
   356  		}
   357  
   358  		buf, err := Verify(buf, jwa.SignatureAlgorithm(key.Algorithm()), key)
   359  		if err != nil {
   360  			continue
   361  		}
   362  
   363  		return buf, nil
   364  	}
   365  
   366  	return nil, errors.New(`failed to verify message with any of the keys in the jwk.Set object`)
   367  }
   368  
// verifyJSON verifies a JWS message in JSON serialization format.
//
// The message is unmarshaled, the signing input is rebuilt for each
// signature in turn, and the payload is returned as soon as one signature
// verifies. When ctx.dst is set it receives a copy of the parsed message
// on success. If no signature verifies, an error is returned.
func (ctx *verifyCtx) verifyJSON(signed []byte) ([]byte, error) {
	// Tell tryVerify that the payload is already decoded.
	ctx.isJSON = true

	var m Message
	// Decode with collectRawCtx installed, and clear the raw data again
	// when this function returns.
	m.SetDecodeCtx(collectRawCtx{})
	defer m.clearRaw()
	if err := json.Unmarshal(signed, &m); err != nil {
		return nil, errors.Wrap(err, `failed to unmarshal JSON message`)
	}
	m.SetDecodeCtx(nil)

	// A detached payload and an embedded payload are mutually exclusive.
	if len(m.payload) != 0 && ctx.detachedPayload != nil {
		return nil, errors.New(`can't specify detached payload for JWS with payload`)
	}

	if ctx.detachedPayload != nil {
		m.payload = ctx.detachedPayload
	}

	// Pre-compute the payload portion of the signing input: base64
	// encoded when m.b64 is set, verbatim otherwise.
	var payload string
	if m.b64 {
		payload = base64.EncodeToString(m.payload)
	} else {
		payload = string(m.payload)
	}

	// One pooled buffer is reused for the signing input of every signature.
	buf := pool.GetBytesBuffer()
	defer pool.ReleaseBytesBuffer(buf)

	for i, sig := range m.signatures {
		buf.Reset()

		// Prefer the raw bytes of the protected header when the header
		// object exposes them, so the signing input is built from those
		// bytes rather than from a re-marshaled header.
		var encodedProtectedHeader string
		if rbp, ok := sig.protected.(interface{ rawBuffer() []byte }); ok {
			if raw := rbp.rawBuffer(); raw != nil {
				encodedProtectedHeader = base64.EncodeToString(raw)
			}
		}

		// Fall back to marshaling the header when no raw bytes exist.
		if encodedProtectedHeader == "" {
			protected, err := json.Marshal(sig.protected)
			if err != nil {
				return nil, errors.Wrapf(err, `failed to marshal "protected" for signature #%d`, i+1)
			}

			encodedProtectedHeader = base64.EncodeToString(protected)
		}

		// Signing input is <protected header> '.' <payload>.
		buf.WriteString(encodedProtectedHeader)
		buf.WriteByte('.')
		buf.WriteString(payload)

		if !ctx.useJKU {
			// When both the signature and a jwk.Key carry a "kid", a
			// mismatch means this signature is not for our key: skip it.
			if hdr := sig.protected; hdr != nil && hdr.KeyID() != "" {
				if jwkKey, ok := ctx.key.(jwk.Key); ok {
					if jwkKey.KeyID() != hdr.KeyID() {
						continue
					}
				}
			}

			verifier, err := NewVerifier(ctx.alg)
			if err != nil {
				return nil, errors.Wrap(err, "failed to create verifier")
			}

			if _, err := ctx.tryVerify(verifier, sig.protected, buf.Bytes(), sig.signature, m.payload); err == nil {
				if ctx.dst != nil {
					*(ctx.dst) = m
				}
				return m.payload, nil
			}
			// Don't fallthrough or bail out. Try the next signature.
			continue
		}

		// JKU mode: the key is discovered from the signature's headers.
		if _, err := ctx.verifyJKU(sig.protected, buf.Bytes(), sig.signature, m.payload); err == nil {
			if ctx.dst != nil {
				*(ctx.dst) = m
			}
			return m.payload, nil
		}
		// try next
	}
	return nil, errors.New(`could not verify with any of the signatures`)
}
   456  
   457  // get the value of b64 header field.
   458  // If the field does not exist, returns true (default)
   459  // Otherwise return the value specified by the header field.
   460  func getB64Value(hdr Headers) bool {
   461  	b64raw, ok := hdr.Get("b64")
   462  	if !ok {
   463  		return true // default
   464  	}
   465  
   466  	b64, ok := b64raw.(bool) // default
   467  	if !ok {
   468  		return false
   469  	}
   470  	return b64
   471  }
   472  
   473  func (ctx *verifyCtx) verifyCompact(signed []byte) ([]byte, error) {
   474  	protected, payload, signature, err := SplitCompact(signed)
   475  	if err != nil {
   476  		return nil, errors.Wrap(err, `failed extract from compact serialization format`)
   477  	}
   478  
   479  	decodedSignature, err := base64.Decode(signature)
   480  	if err != nil {
   481  		return nil, errors.Wrap(err, `failed to decode signature`)
   482  	}
   483  
   484  	hdr := NewHeaders()
   485  	decodedProtected, err := base64.Decode(protected)
   486  	if err != nil {
   487  		return nil, errors.Wrap(err, `failed to decode headers`)
   488  	}
   489  
   490  	if err := json.Unmarshal(decodedProtected, hdr); err != nil {
   491  		return nil, errors.Wrap(err, `failed to decode headers`)
   492  	}
   493  
   494  	verifyBuf := pool.GetBytesBuffer()
   495  	defer pool.ReleaseBytesBuffer(verifyBuf)
   496  
   497  	verifyBuf.Write(protected)
   498  	verifyBuf.WriteByte('.')
   499  	if len(payload) == 0 && ctx.detachedPayload != nil {
   500  		if getB64Value(hdr) {
   501  			payload = base64.Encode(ctx.detachedPayload)
   502  		} else {
   503  			payload = ctx.detachedPayload
   504  		}
   505  	}
   506  	verifyBuf.Write(payload)
   507  
   508  	if !ctx.useJKU {
   509  		if hdr.KeyID() != "" {
   510  			if jwkKey, ok := ctx.key.(jwk.Key); ok {
   511  				if jwkKey.KeyID() != hdr.KeyID() {
   512  					return nil, errors.New(`"kid" fields do not match`)
   513  				}
   514  			}
   515  		}
   516  
   517  		verifier, err := NewVerifier(ctx.alg)
   518  		if err != nil {
   519  			return nil, errors.Wrap(err, "failed to create verifier")
   520  		}
   521  
   522  		return ctx.tryVerify(verifier, hdr, verifyBuf.Bytes(), decodedSignature, payload)
   523  	}
   524  
   525  	return ctx.verifyJKU(hdr, verifyBuf.Bytes(), decodedSignature, payload)
   526  }
   527  
// JWKSetFetcher is used to fetch the JWK Set specified in the `jku` field.
// Fetch receives the URL as a string and returns the set found there.
type JWKSetFetcher interface {
	Fetch(string) (jwk.Set, error)
}
   532  
// SimpleJWKSetFetcher is the default object used to fetch JWK Sets specified in `jku`,
// which uses `jwk.Fetch()`
//
// For more complicated cases, such as using `jwk.AutoRefetch`, you will have to
// create your custom instance of `jws.JWKSetFetcher`
type SimpleJWKSetFetcher struct {
	// options are passed to jwk.Fetch on every call to Fetch.
	options []jwk.FetchOption
}
   541  
// NewJWKSetFetcher creates a SimpleJWKSetFetcher that passes the given
// options to jwk.Fetch on every fetch.
func NewJWKSetFetcher(options ...jwk.FetchOption) *SimpleJWKSetFetcher {
	return &SimpleJWKSetFetcher{options: options}
}
   545  
// Fetch retrieves the JWK Set at URL u by calling jwk.Fetch with the
// options stored in f. The call uses context.TODO(), so callers cannot
// cancel it through a context of their own.
func (f *SimpleJWKSetFetcher) Fetch(u string) (jwk.Set, error) {
	return jwk.Fetch(context.TODO(), u, f.options...)
}
   549  
// JWKSetFetchFunc is a function type with the same signature as
// JWKSetFetcher.Fetch; its Fetch method lets a plain function be used
// as a JWKSetFetcher.
type JWKSetFetchFunc func(string) (jwk.Set, error)
   551  
// Fetch calls f with u and returns its result.
func (f JWKSetFetchFunc) Fetch(u string) (jwk.Set, error) {
	return f(u)
}
   555  
   556  func (ctx *verifyCtx) verifyJKU(hdr Headers, verifyBuf, decodedSignature, payload []byte) ([]byte, error) {
   557  	u := hdr.JWKSetURL()
   558  	if u == "" {
   559  		return nil, errors.New(`use of "jku" field specified, but the field is empty`)
   560  	}
   561  	uo, err := url.Parse(u)
   562  	if err != nil {
   563  		return nil, errors.Wrap(err, `failed to parse "jku"`)
   564  	}
   565  	if uo.Scheme != "https" {
   566  		return nil, errors.New(`url in "jku" must be HTTPS`)
   567  	}
   568  
   569  	set, err := ctx.jwksFetcher.Fetch(u)
   570  	if err != nil {
   571  		return nil, errors.Wrapf(err, `failed to fetch "jku"`)
   572  	}
   573  
   574  	// Because we're using a JWKS here, we MUST have "kid" that matches
   575  	// the payload
   576  	if hdr.KeyID() == "" {
   577  		return nil, errors.Errorf(`"kid" is required on the JWS message to use "jku"`)
   578  	}
   579  
   580  	key, ok := set.LookupKeyID(hdr.KeyID())
   581  	if !ok {
   582  		return nil, errors.New(`key specified via "kid" is not present in the JWK set specified by "jku"`)
   583  	}
   584  
   585  	// hooray, we found a key. Now the algorithm will have to be inferred.
   586  	algs, err := AlgorithmsForKey(key)
   587  	if err != nil {
   588  		return nil, errors.Wrapf(err, `failed to get a list of signature methods for key type %s`, key.KeyType())
   589  	}
   590  
   591  	// for each of these algorithms, just ... keep trying ...
   592  	ctx.key = key
   593  	hdrAlg := hdr.Algorithm()
   594  	for _, alg := range algs {
   595  		// if we have a "alg" field in the JWS, we can only proceed if
   596  		// the inferred algorithm matches
   597  		if hdrAlg != "" && hdrAlg != alg {
   598  			continue
   599  		}
   600  
   601  		verifier, err := NewVerifier(alg)
   602  		if err != nil {
   603  			return nil, errors.Wrap(err, "failed to create verifier")
   604  		}
   605  
   606  		if decoded, err := ctx.tryVerify(verifier, hdr, verifyBuf, decodedSignature, payload); err == nil {
   607  			return decoded, nil
   608  		}
   609  	}
   610  	return nil, errors.New(`failed to verify payload using key in "jku"`)
   611  }
   612  
   613  func (ctx *verifyCtx) tryVerify(verifier Verifier, hdr Headers, buf, decodedSignature, payload []byte) ([]byte, error) {
   614  	if err := verifier.Verify(buf, decodedSignature, ctx.key); err != nil {
   615  		return nil, errors.Wrap(err, `failed to verify message`)
   616  	}
   617  
   618  	var decodedPayload []byte
   619  
   620  	// When verifying JSON messages, we do not need to decode
   621  	// the payload, as we already have it
   622  	if !ctx.isJSON {
   623  		// This is a special case for RFC7797
   624  		if !getB64Value(hdr) { // it's not base64 encoded
   625  			decodedPayload = payload
   626  		}
   627  
   628  		if decodedPayload == nil {
   629  			v, err := base64.Decode(payload)
   630  			if err != nil {
   631  				return nil, errors.Wrap(err, `message verified, failed to decode payload`)
   632  			}
   633  			decodedPayload = v
   634  		}
   635  
   636  		// For compact serialization, we need to create and assign the message
   637  		// if requested
   638  		if ctx.dst != nil {
   639  			// Construct a new Message object
   640  			m := NewMessage()
   641  			m.SetPayload(decodedPayload)
   642  			sig := NewSignature()
   643  			sig.SetProtectedHeaders(hdr)
   644  			sig.SetSignature(decodedSignature)
   645  			m.AppendSignature(sig)
   646  
   647  			*(ctx.dst) = *m
   648  		}
   649  	}
   650  	return decodedPayload, nil
   651  }
   652  
   653  // This is an "optimized" ioutil.ReadAll(). It will attempt to read
   654  // all of the contents from the reader IF the reader is of a certain
   655  // concrete type.
   656  func readAll(rdr io.Reader) ([]byte, bool) {
   657  	switch rdr.(type) {
   658  	case *bytes.Reader, *bytes.Buffer, *strings.Reader:
   659  		data, err := ioutil.ReadAll(rdr)
   660  		if err != nil {
   661  			return nil, false
   662  		}
   663  		return data, true
   664  	default:
   665  		return nil, false
   666  	}
   667  }
   668  
   669  // Parse parses contents from the given source and creates a jws.Message
   670  // struct. The input can be in either compact or full JSON serialization.
   671  func Parse(src []byte) (*Message, error) {
   672  	for i := 0; i < len(src); i++ {
   673  		r := rune(src[i])
   674  		if r >= utf8.RuneSelf {
   675  			r, _ = utf8.DecodeRune(src)
   676  		}
   677  		if !unicode.IsSpace(r) {
   678  			if r == '{' {
   679  				return parseJSON(src)
   680  			}
   681  			return parseCompact(src)
   682  		}
   683  	}
   684  	return nil, errors.New("invalid byte sequence")
   685  }
   686  
// ParseString parses contents from the given string and creates a
// jws.Message struct. The input can be in either compact or full JSON
// serialization. It converts src to a byte slice and calls Parse.
func ParseString(src string) (*Message, error) {
	return Parse([]byte(src))
}
   692  
   693  // Parse parses contents from the given source and creates a jws.Message
   694  // struct. The input can be in either compact or full JSON serialization.
   695  func ParseReader(src io.Reader) (*Message, error) {
   696  	if data, ok := readAll(src); ok {
   697  		return Parse(data)
   698  	}
   699  
   700  	rdr := bufio.NewReader(src)
   701  	var first rune
   702  	for {
   703  		r, _, err := rdr.ReadRune()
   704  		if err != nil {
   705  			return nil, errors.Wrap(err, `failed to read rune`)
   706  		}
   707  		if !unicode.IsSpace(r) {
   708  			first = r
   709  			if err := rdr.UnreadRune(); err != nil {
   710  				return nil, errors.Wrap(err, `failed to unread rune`)
   711  			}
   712  
   713  			break
   714  		}
   715  	}
   716  
   717  	var parser func(io.Reader) (*Message, error)
   718  	if first == '{' {
   719  		parser = parseJSONReader
   720  	} else {
   721  		parser = parseCompactReader
   722  	}
   723  
   724  	m, err := parser(rdr)
   725  	if err != nil {
   726  		return nil, errors.Wrap(err, `failed to parse jws message`)
   727  	}
   728  
   729  	return m, nil
   730  }
   731  
   732  func parseJSONReader(src io.Reader) (result *Message, err error) {
   733  	var m Message
   734  	if err := json.NewDecoder(src).Decode(&m); err != nil {
   735  		return nil, errors.Wrap(err, `failed to unmarshal jws message`)
   736  	}
   737  	return &m, nil
   738  }
   739  
   740  func parseJSON(data []byte) (result *Message, err error) {
   741  	var m Message
   742  	if err := json.Unmarshal(data, &m); err != nil {
   743  		return nil, errors.Wrap(err, `failed to unmarshal jws message`)
   744  	}
   745  	return &m, nil
   746  }
   747  
   748  // SplitCompact splits a JWT and returns its three parts
   749  // separately: protected headers, payload and signature.
   750  func SplitCompact(src []byte) ([]byte, []byte, []byte, error) {
   751  	parts := bytes.Split(src, []byte("."))
   752  	if len(parts) < 3 {
   753  		return nil, nil, nil, errors.New(`invalid number of segments`)
   754  	}
   755  	return parts[0], parts[1], parts[2], nil
   756  }
   757  
   758  // SplitCompactString splits a JWT and returns its three parts
   759  // separately: protected headers, payload and signature.
   760  func SplitCompactString(src string) ([]byte, []byte, []byte, error) {
   761  	parts := strings.Split(src, ".")
   762  	if len(parts) < 3 {
   763  		return nil, nil, nil, errors.New(`invalid number of segments`)
   764  	}
   765  	return []byte(parts[0]), []byte(parts[1]), []byte(parts[2]), nil
   766  }
   767  
// SplitCompactReader splits a JWS in compact serialization format read
// from rdr and returns its three parts separately: protected headers,
// payload and signature.
//
// Readers that readAll can drain directly are delegated to SplitCompact.
// Any other reader is consumed in 4096-byte chunks while scanning for the
// '.' separators; the input must contain exactly two of them, otherwise
// an error is returned.
func SplitCompactReader(rdr io.Reader) ([]byte, []byte, []byte, error) {
	if data, ok := readAll(rdr); ok {
		return SplitCompact(data)
	}

	var protected []byte
	var payload []byte
	var signature []byte
	// periods counts every '.' seen in the input.
	var periods int
	// state selects which segment the next completed span belongs to:
	// 0 = protected, 1 = payload, 2 = signature.
	var state int

	buf := make([]byte, 4096)
	// sofar accumulates bytes of the segment currently being collected.
	var sofar []byte

	for {
		// read next bytes
		n, err := rdr.Read(buf)
		// return on unexpected read error
		if err != nil && err != io.EOF {
			return nil, nil, nil, errors.Wrap(err, `unexpected end of input`)
		}

		// append to current buffer
		sofar = append(sofar, buf[:n]...)
		// loop to capture multiple '.' in current buffer
		for loop := true; loop; {
			var i = bytes.IndexByte(sofar, '.')
			if i == -1 && err != io.EOF {
				// no '.' found -> exit and read next bytes (outer loop)
				loop = false
				continue
			} else if i == -1 && err == io.EOF {
				// no '.' found -> process rest and exit
				i = len(sofar)
				loop = false
			} else {
				// '.' found
				periods++
			}

			// Reaching this point means we have found a '.' or EOF and process the rest of the buffer
			switch state {
			case 0:
				protected = sofar[:i]
				state++
			case 1:
				payload = sofar[:i]
				state++
			case 2:
				// state stays at 2, so any further span overwrites
				// signature; such input is rejected below because
				// periods will exceed 2.
				signature = sofar[:i]
			}
			// Shorten current buffer
			// (skipped at EOF without '.', where i == len(sofar))
			if len(sofar) > i {
				sofar = sofar[i+1:]
			}
		}
		// Exit on EOF
		if err == io.EOF {
			break
		}
	}
	// Exactly two separators make three segments.
	if periods != 2 {
		return nil, nil, nil, errors.New(`invalid number of segments`)
	}

	return protected, payload, signature, nil
}
   837  
   838  // parseCompactReader parses a JWS value serialized via compact serialization.
   839  func parseCompactReader(rdr io.Reader) (m *Message, err error) {
   840  	protected, payload, signature, err := SplitCompactReader(rdr)
   841  	if err != nil {
   842  		return nil, errors.Wrap(err, `invalid compact serialization format`)
   843  	}
   844  	return parse(protected, payload, signature)
   845  }
   846  
   847  func parseCompact(data []byte) (m *Message, err error) {
   848  	protected, payload, signature, err := SplitCompact(data)
   849  	if err != nil {
   850  		return nil, errors.Wrap(err, `invalid compact serialization format`)
   851  	}
   852  	return parse(protected, payload, signature)
   853  }
   854  
   855  func parse(protected, payload, signature []byte) (*Message, error) {
   856  	decodedHeader, err := base64.Decode(protected)
   857  	if err != nil {
   858  		return nil, errors.Wrap(err, `failed to decode protected headers`)
   859  	}
   860  
   861  	hdr := NewHeaders()
   862  	if err := json.Unmarshal(decodedHeader, hdr); err != nil {
   863  		return nil, errors.Wrap(err, `failed to parse JOSE headers`)
   864  	}
   865  
   866  	decodedPayload, err := base64.Decode(payload)
   867  	if err != nil {
   868  		return nil, errors.Wrap(err, `failed to decode payload`)
   869  	}
   870  
   871  	decodedSignature, err := base64.Decode(signature)
   872  	if err != nil {
   873  		return nil, errors.Wrap(err, `failed to decode signature`)
   874  	}
   875  
   876  	var msg Message
   877  	msg.payload = decodedPayload
   878  	msg.signatures = append(msg.signatures, &Signature{
   879  		protected: hdr,
   880  		signature: decodedSignature,
   881  	})
   882  	return &msg, nil
   883  }
   884  
// RegisterCustomField allows users to specify that a private field
// be decoded as an instance of the specified type. This option has
// a global effect.
//
// The call simply forwards `name` and `object` to the package-level
// registry.
//
// For example, suppose you have a custom field `x-birthday`, which
// you want to represent as a string formatted in RFC3339 in JSON,
// but want it back as `time.Time`.
//
// In that case you would register a custom field as follows
//
//   jws.RegisterCustomField(`x-birthday`, timeT)
//
// Here `timeT` is a placeholder for a value describing the desired type.
// NOTE(review): presumably a `time.Time` value such as `time.Time{}` is
// expected; confirm against the registry implementation.
//
// Then `hdr.Get("x-birthday")` will still return an `interface{}`,
// but you can convert its type to `time.Time`
//
//   bdayif, _ := hdr.Get(`x-birthday`)
//   bday := bdayif.(time.Time)
func RegisterCustomField(name string, object interface{}) {
	registry.Register(name, object)
}
   906  
   907  // Helpers for signature verification
   908  var rawKeyToKeyType = make(map[reflect.Type]jwa.KeyType)
   909  var keyTypeToAlgorithms = make(map[jwa.KeyType][]jwa.SignatureAlgorithm)
   910  
   911  func init() {
   912  	rawKeyToKeyType[reflect.TypeOf([]byte(nil))] = jwa.OctetSeq
   913  	rawKeyToKeyType[reflect.TypeOf(ed25519.PublicKey(nil))] = jwa.OKP
   914  	rawKeyToKeyType[reflect.TypeOf(rsa.PublicKey{})] = jwa.RSA
   915  	rawKeyToKeyType[reflect.TypeOf((*rsa.PublicKey)(nil))] = jwa.RSA
   916  	rawKeyToKeyType[reflect.TypeOf(ecdsa.PublicKey{})] = jwa.EC
   917  	rawKeyToKeyType[reflect.TypeOf((*ecdsa.PublicKey)(nil))] = jwa.EC
   918  
   919  	addAlgorithmForKeyType(jwa.OKP, jwa.EdDSA)
   920  	for _, alg := range []jwa.SignatureAlgorithm{jwa.HS256, jwa.HS384, jwa.HS512} {
   921  		addAlgorithmForKeyType(jwa.OctetSeq, alg)
   922  	}
   923  	for _, alg := range []jwa.SignatureAlgorithm{jwa.RS256, jwa.RS384, jwa.RS512, jwa.PS256, jwa.PS384, jwa.PS512} {
   924  		addAlgorithmForKeyType(jwa.RSA, alg)
   925  	}
   926  	for _, alg := range []jwa.SignatureAlgorithm{jwa.ES256, jwa.ES384, jwa.ES512} {
   927  		addAlgorithmForKeyType(jwa.EC, alg)
   928  	}
   929  }
   930  
   931  func addAlgorithmForKeyType(kty jwa.KeyType, alg jwa.SignatureAlgorithm) {
   932  	keyTypeToAlgorithms[kty] = append(keyTypeToAlgorithms[kty], alg)
   933  }
   934  
   935  // AlgorithmsForKey returns the possible signature algorithms that can
   936  // be used for a given key. It only takes in consideration keys/algorithms
   937  // for verification purposes, as this is the only usage where one may need
   938  // dynamically figure out which method to use.
   939  func AlgorithmsForKey(key interface{}) ([]jwa.SignatureAlgorithm, error) {
   940  	var kty jwa.KeyType
   941  	switch key := key.(type) {
   942  	case jwk.Key:
   943  		kty = key.KeyType()
   944  	case rsa.PublicKey, *rsa.PublicKey, rsa.PrivateKey, *rsa.PrivateKey:
   945  		kty = jwa.RSA
   946  	case ecdsa.PublicKey, *ecdsa.PublicKey, ecdsa.PrivateKey, *ecdsa.PrivateKey:
   947  		kty = jwa.EC
   948  	case ed25519.PublicKey, ed25519.PrivateKey, x25519.PublicKey, x25519.PrivateKey:
   949  		kty = jwa.OKP
   950  	case []byte:
   951  		kty = jwa.OctetSeq
   952  	default:
   953  		return nil, errors.Errorf(`invalid key %T`, key)
   954  	}
   955  
   956  	algs, ok := keyTypeToAlgorithms[kty]
   957  	if !ok {
   958  		return nil, errors.Errorf(`invalid key type %q`, kty)
   959  	}
   960  	return algs, nil
   961  }
   962  

View as plain text