...

Source file src/github.com/lestrrat-go/jwx/jwe/decrypt.go

Documentation: github.com/lestrrat-go/jwx/jwe

     1  package jwe
     2  
     3  import (
     4  	"crypto/aes"
     5  	cryptocipher "crypto/cipher"
     6  	"crypto/ecdsa"
     7  	"crypto/rsa"
     8  	"crypto/sha256"
     9  	"crypto/sha512"
    10  	"hash"
    11  
    12  	"golang.org/x/crypto/pbkdf2"
    13  
    14  	"github.com/lestrrat-go/jwx/internal/keyconv"
    15  	"github.com/lestrrat-go/jwx/jwa"
    16  	"github.com/lestrrat-go/jwx/jwe/internal/cipher"
    17  	"github.com/lestrrat-go/jwx/jwe/internal/content_crypt"
    18  	"github.com/lestrrat-go/jwx/jwe/internal/keyenc"
    19  	"github.com/lestrrat-go/jwx/x25519"
    20  	"github.com/pkg/errors"
    21  )
    22  
    23  // Decrypter is responsible for taking various components to decrypt a message.
    24  // its operation is not concurrency safe. You must provide locking yourself
    25  //nolint:govet
    26  type Decrypter struct {
    27  	aad         []byte
    28  	apu         []byte
    29  	apv         []byte
    30  	computedAad []byte
    31  	iv          []byte
    32  	keyiv       []byte
    33  	keysalt     []byte
    34  	keytag      []byte
    35  	tag         []byte
    36  	privkey     interface{}
    37  	pubkey      interface{}
    38  	ctalg       jwa.ContentEncryptionAlgorithm
    39  	keyalg      jwa.KeyEncryptionAlgorithm
    40  	cipher      content_crypt.Cipher
    41  	keycount    int
    42  }
    43  
    44  // NewDecrypter Creates a new Decrypter instance. You must supply the
    45  // rest of parameters via their respective setter methods before
    46  // calling Decrypt().
    47  //
    48  // privkey must be a private key in its "raw" format (i.e. something like
    49  // *rsa.PrivateKey, instead of jwk.Key)
    50  //
    51  // You should consider this object immutable once you assign values to it.
    52  func NewDecrypter(keyalg jwa.KeyEncryptionAlgorithm, ctalg jwa.ContentEncryptionAlgorithm, privkey interface{}) *Decrypter {
    53  	return &Decrypter{
    54  		ctalg:   ctalg,
    55  		keyalg:  keyalg,
    56  		privkey: privkey,
    57  	}
    58  }
    59  
    60  func (d *Decrypter) AgreementPartyUInfo(apu []byte) *Decrypter {
    61  	d.apu = apu
    62  	return d
    63  }
    64  
    65  func (d *Decrypter) AgreementPartyVInfo(apv []byte) *Decrypter {
    66  	d.apv = apv
    67  	return d
    68  }
    69  
    70  func (d *Decrypter) AuthenticatedData(aad []byte) *Decrypter {
    71  	d.aad = aad
    72  	return d
    73  }
    74  
    75  func (d *Decrypter) ComputedAuthenticatedData(aad []byte) *Decrypter {
    76  	d.computedAad = aad
    77  	return d
    78  }
    79  
    80  func (d *Decrypter) ContentEncryptionAlgorithm(ctalg jwa.ContentEncryptionAlgorithm) *Decrypter {
    81  	d.ctalg = ctalg
    82  	return d
    83  }
    84  
    85  func (d *Decrypter) InitializationVector(iv []byte) *Decrypter {
    86  	d.iv = iv
    87  	return d
    88  }
    89  
    90  func (d *Decrypter) KeyCount(keycount int) *Decrypter {
    91  	d.keycount = keycount
    92  	return d
    93  }
    94  
    95  func (d *Decrypter) KeyInitializationVector(keyiv []byte) *Decrypter {
    96  	d.keyiv = keyiv
    97  	return d
    98  }
    99  
   100  func (d *Decrypter) KeySalt(keysalt []byte) *Decrypter {
   101  	d.keysalt = keysalt
   102  	return d
   103  }
   104  
   105  func (d *Decrypter) KeyTag(keytag []byte) *Decrypter {
   106  	d.keytag = keytag
   107  	return d
   108  }
   109  
   110  // PublicKey sets the public key to be used in decoding EC based encryptions.
   111  // The key must be in its "raw" format (i.e. *ecdsa.PublicKey, instead of jwk.Key)
   112  func (d *Decrypter) PublicKey(pubkey interface{}) *Decrypter {
   113  	d.pubkey = pubkey
   114  	return d
   115  }
   116  
   117  func (d *Decrypter) Tag(tag []byte) *Decrypter {
   118  	d.tag = tag
   119  	return d
   120  }
   121  
   122  func (d *Decrypter) ContentCipher() (content_crypt.Cipher, error) {
   123  	if d.cipher == nil {
   124  		switch d.ctalg {
   125  		case jwa.A128GCM, jwa.A192GCM, jwa.A256GCM, jwa.A128CBC_HS256, jwa.A192CBC_HS384, jwa.A256CBC_HS512:
   126  			cipher, err := cipher.NewAES(d.ctalg)
   127  			if err != nil {
   128  				return nil, errors.Wrapf(err, `failed to build content cipher for %s`, d.ctalg)
   129  			}
   130  			d.cipher = cipher
   131  		default:
   132  			return nil, errors.Errorf(`invalid content cipher algorithm (%s)`, d.ctalg)
   133  		}
   134  	}
   135  
   136  	return d.cipher, nil
   137  }
   138  
   139  func (d *Decrypter) Decrypt(recipientKey, ciphertext []byte) (plaintext []byte, err error) {
   140  	cek, keyerr := d.DecryptKey(recipientKey)
   141  	if keyerr != nil {
   142  		err = errors.Wrap(keyerr, `failed to decrypt key`)
   143  		return
   144  	}
   145  
   146  	cipher, ciphererr := d.ContentCipher()
   147  	if ciphererr != nil {
   148  		err = errors.Wrap(ciphererr, `failed to fetch content crypt cipher`)
   149  		return
   150  	}
   151  
   152  	computedAad := d.computedAad
   153  	if d.aad != nil {
   154  		computedAad = append(append(computedAad, '.'), d.aad...)
   155  	}
   156  
   157  	plaintext, err = cipher.Decrypt(cek, d.iv, ciphertext, d.tag, computedAad)
   158  	if err != nil {
   159  		err = errors.Wrap(err, `failed to decrypt payload`)
   160  		return
   161  	}
   162  
   163  	return plaintext, nil
   164  }
   165  
   166  func (d *Decrypter) decryptSymmetricKey(recipientKey, cek []byte) ([]byte, error) {
   167  	switch d.keyalg {
   168  	case jwa.DIRECT:
   169  		return cek, nil
   170  	case jwa.PBES2_HS256_A128KW, jwa.PBES2_HS384_A192KW, jwa.PBES2_HS512_A256KW:
   171  		var hashFunc func() hash.Hash
   172  		var keylen int
   173  		switch d.keyalg {
   174  		case jwa.PBES2_HS256_A128KW:
   175  			hashFunc = sha256.New
   176  			keylen = 16
   177  		case jwa.PBES2_HS384_A192KW:
   178  			hashFunc = sha512.New384
   179  			keylen = 24
   180  		case jwa.PBES2_HS512_A256KW:
   181  			hashFunc = sha512.New
   182  			keylen = 32
   183  		}
   184  		salt := []byte(d.keyalg)
   185  		salt = append(salt, byte(0))
   186  		salt = append(salt, d.keysalt...)
   187  		cek = pbkdf2.Key(cek, salt, d.keycount, keylen, hashFunc)
   188  		fallthrough
   189  	case jwa.A128KW, jwa.A192KW, jwa.A256KW:
   190  		block, err := aes.NewCipher(cek)
   191  		if err != nil {
   192  			return nil, errors.Wrap(err, `failed to create new AES cipher`)
   193  		}
   194  
   195  		jek, err := keyenc.Unwrap(block, recipientKey)
   196  		if err != nil {
   197  			return nil, errors.Wrap(err, `failed to unwrap key`)
   198  		}
   199  
   200  		return jek, nil
   201  	case jwa.A128GCMKW, jwa.A192GCMKW, jwa.A256GCMKW:
   202  		if len(d.keyiv) != 12 {
   203  			return nil, errors.Errorf("GCM requires 96-bit iv, got %d", len(d.keyiv)*8)
   204  		}
   205  		if len(d.keytag) != 16 {
   206  			return nil, errors.Errorf("GCM requires 128-bit tag, got %d", len(d.keytag)*8)
   207  		}
   208  		block, err := aes.NewCipher(cek)
   209  		if err != nil {
   210  			return nil, errors.Wrap(err, `failed to create new AES cipher`)
   211  		}
   212  		aesgcm, err := cryptocipher.NewGCM(block)
   213  		if err != nil {
   214  			return nil, errors.Wrap(err, `failed to create new GCM wrap`)
   215  		}
   216  		ciphertext := recipientKey[:]
   217  		ciphertext = append(ciphertext, d.keytag...)
   218  		jek, err := aesgcm.Open(nil, d.keyiv, ciphertext, nil)
   219  		if err != nil {
   220  			return nil, errors.Wrap(err, `failed to decode key`)
   221  		}
   222  		return jek, nil
   223  	default:
   224  		return nil, errors.Errorf("decrypt key: unsupported algorithm %s", d.keyalg)
   225  	}
   226  }
   227  
   228  func (d *Decrypter) DecryptKey(recipientKey []byte) (cek []byte, err error) {
   229  	if d.keyalg.IsSymmetric() {
   230  		var ok bool
   231  		cek, ok = d.privkey.([]byte)
   232  		if !ok {
   233  			return nil, errors.Errorf("decrypt key: []byte is required as the key to build %s key decrypter (got %T)", d.keyalg, d.privkey)
   234  		}
   235  
   236  		return d.decryptSymmetricKey(recipientKey, cek)
   237  	}
   238  
   239  	k, err := d.BuildKeyDecrypter()
   240  	if err != nil {
   241  		return nil, errors.Wrap(err, `failed to build key decrypter`)
   242  	}
   243  
   244  	cek, err = k.Decrypt(recipientKey)
   245  	if err != nil {
   246  		return nil, errors.Wrap(err, `failed to decrypt key`)
   247  	}
   248  
   249  	return cek, nil
   250  }
   251  
   252  func (d *Decrypter) BuildKeyDecrypter() (keyenc.Decrypter, error) {
   253  	cipher, err := d.ContentCipher()
   254  	if err != nil {
   255  		return nil, errors.Wrap(err, `failed to fetch content crypt cipher`)
   256  	}
   257  
   258  	switch alg := d.keyalg; alg {
   259  	case jwa.RSA1_5:
   260  		var privkey rsa.PrivateKey
   261  		if err := keyconv.RSAPrivateKey(&privkey, d.privkey); err != nil {
   262  			return nil, errors.Wrapf(err, "*rsa.PrivateKey is required as the key to build %s key decrypter", alg)
   263  		}
   264  
   265  		return keyenc.NewRSAPKCS15Decrypt(alg, &privkey, cipher.KeySize()/2), nil
   266  	case jwa.RSA_OAEP, jwa.RSA_OAEP_256:
   267  		var privkey rsa.PrivateKey
   268  		if err := keyconv.RSAPrivateKey(&privkey, d.privkey); err != nil {
   269  			return nil, errors.Wrapf(err, "*rsa.PrivateKey is required as the key to build %s key decrypter", alg)
   270  		}
   271  
   272  		return keyenc.NewRSAOAEPDecrypt(alg, &privkey)
   273  	case jwa.A128KW, jwa.A192KW, jwa.A256KW:
   274  		sharedkey, ok := d.privkey.([]byte)
   275  		if !ok {
   276  			return nil, errors.Errorf("[]byte is required as the key to build %s key decrypter", alg)
   277  		}
   278  
   279  		return keyenc.NewAES(alg, sharedkey)
   280  	case jwa.ECDH_ES, jwa.ECDH_ES_A128KW, jwa.ECDH_ES_A192KW, jwa.ECDH_ES_A256KW:
   281  		switch d.pubkey.(type) {
   282  		case x25519.PublicKey:
   283  			return keyenc.NewECDHESDecrypt(alg, d.ctalg, d.pubkey, d.apu, d.apv, d.privkey), nil
   284  		default:
   285  			var pubkey ecdsa.PublicKey
   286  			if err := keyconv.ECDSAPublicKey(&pubkey, d.pubkey); err != nil {
   287  				return nil, errors.Wrapf(err, "*ecdsa.PublicKey is required as the key to build %s key decrypter", alg)
   288  			}
   289  
   290  			var privkey ecdsa.PrivateKey
   291  			if err := keyconv.ECDSAPrivateKey(&privkey, d.privkey); err != nil {
   292  				return nil, errors.Wrapf(err, "*ecdsa.PrivateKey is required as the key to build %s key decrypter", alg)
   293  			}
   294  
   295  			return keyenc.NewECDHESDecrypt(alg, d.ctalg, &pubkey, d.apu, d.apv, &privkey), nil
   296  		}
   297  	default:
   298  		return nil, errors.Errorf(`unsupported algorithm for key decryption (%s)`, alg)
   299  	}
   300  }
   301  

View as plain text