...

Source file src/github.com/lestrrat-go/jwx/jws/rsa.go

Documentation: github.com/lestrrat-go/jwx/jws

     1  package jws
     2  
     3  import (
     4  	"crypto"
     5  	"crypto/rand"
     6  	"crypto/rsa"
     7  
     8  	"github.com/lestrrat-go/jwx/internal/keyconv"
     9  	"github.com/lestrrat-go/jwx/jwa"
    10  	"github.com/pkg/errors"
    11  )
    12  
    13  var rsaSigners map[jwa.SignatureAlgorithm]*rsaSigner
    14  var rsaVerifiers map[jwa.SignatureAlgorithm]*rsaVerifier
    15  
    16  func init() {
    17  	algs := map[jwa.SignatureAlgorithm]struct {
    18  		Hash crypto.Hash
    19  		PSS  bool
    20  	}{
    21  		jwa.RS256: {
    22  			Hash: crypto.SHA256,
    23  		},
    24  		jwa.RS384: {
    25  			Hash: crypto.SHA384,
    26  		},
    27  		jwa.RS512: {
    28  			Hash: crypto.SHA512,
    29  		},
    30  		jwa.PS256: {
    31  			Hash: crypto.SHA256,
    32  			PSS:  true,
    33  		},
    34  		jwa.PS384: {
    35  			Hash: crypto.SHA384,
    36  			PSS:  true,
    37  		},
    38  		jwa.PS512: {
    39  			Hash: crypto.SHA512,
    40  			PSS:  true,
    41  		},
    42  	}
    43  
    44  	rsaSigners = make(map[jwa.SignatureAlgorithm]*rsaSigner)
    45  	rsaVerifiers = make(map[jwa.SignatureAlgorithm]*rsaVerifier)
    46  	for alg, item := range algs {
    47  		rsaSigners[alg] = &rsaSigner{
    48  			alg:  alg,
    49  			hash: item.Hash,
    50  			pss:  item.PSS,
    51  		}
    52  		rsaVerifiers[alg] = &rsaVerifier{
    53  			alg:  alg,
    54  			hash: item.Hash,
    55  			pss:  item.PSS,
    56  		}
    57  	}
    58  }
    59  
    60  type rsaSigner struct {
    61  	alg  jwa.SignatureAlgorithm
    62  	hash crypto.Hash
    63  	pss  bool
    64  }
    65  
    66  func newRSASigner(alg jwa.SignatureAlgorithm) Signer {
    67  	return rsaSigners[alg]
    68  }
    69  
    70  func (rs *rsaSigner) Algorithm() jwa.SignatureAlgorithm {
    71  	return rs.alg
    72  }
    73  
    74  func (rs *rsaSigner) Sign(payload []byte, key interface{}) ([]byte, error) {
    75  	if key == nil {
    76  		return nil, errors.New(`missing private key while signing payload`)
    77  	}
    78  
    79  	signer, ok := key.(crypto.Signer)
    80  	if !ok {
    81  		var privkey rsa.PrivateKey
    82  		if err := keyconv.RSAPrivateKey(&privkey, key); err != nil {
    83  			return nil, errors.Wrapf(err, `failed to retrieve rsa.PrivateKey out of %T`, key)
    84  		}
    85  		signer = &privkey
    86  	}
    87  
    88  	h := rs.hash.New()
    89  	if _, err := h.Write(payload); err != nil {
    90  		return nil, errors.Wrap(err, "failed to write payload to hash")
    91  	}
    92  	if rs.pss {
    93  		return signer.Sign(rand.Reader, h.Sum(nil), &rsa.PSSOptions{
    94  			Hash:       rs.hash,
    95  			SaltLength: rsa.PSSSaltLengthEqualsHash,
    96  		})
    97  	}
    98  	return signer.Sign(rand.Reader, h.Sum(nil), rs.hash)
    99  }
   100  
   101  type rsaVerifier struct {
   102  	alg  jwa.SignatureAlgorithm
   103  	hash crypto.Hash
   104  	pss  bool
   105  }
   106  
   107  func newRSAVerifier(alg jwa.SignatureAlgorithm) Verifier {
   108  	return rsaVerifiers[alg]
   109  }
   110  
   111  func (rv *rsaVerifier) Verify(payload, signature []byte, key interface{}) error {
   112  	if key == nil {
   113  		return errors.New(`missing public key while verifying payload`)
   114  	}
   115  
   116  	var pubkey rsa.PublicKey
   117  	if cs, ok := key.(crypto.Signer); ok {
   118  		cpub := cs.Public()
   119  		switch cpub := cpub.(type) {
   120  		case rsa.PublicKey:
   121  			pubkey = cpub
   122  		case *rsa.PublicKey:
   123  			pubkey = *cpub
   124  		default:
   125  			return errors.Errorf(`failed to retrieve rsa.PublicKey out of crypto.Signer %T`, key)
   126  		}
   127  	} else {
   128  		if err := keyconv.RSAPublicKey(&pubkey, key); err != nil {
   129  			return errors.Wrapf(err, `failed to retrieve rsa.PublicKey out of %T`, key)
   130  		}
   131  	}
   132  
   133  	h := rv.hash.New()
   134  	if _, err := h.Write(payload); err != nil {
   135  		return errors.Wrap(err, "failed to write payload to hash")
   136  	}
   137  
   138  	if rv.pss {
   139  		return rsa.VerifyPSS(&pubkey, rv.hash, h.Sum(nil), signature, nil)
   140  	}
   141  	return rsa.VerifyPKCS1v15(&pubkey, rv.hash, h.Sum(nil), signature)
   142  }
   143  

View as plain text