...

Source file src/gopkg.in/square/go-jose.v2/crypter.go

Documentation: gopkg.in/square/go-jose.v2

     1  /*-
     2   * Copyright 2014 Square Inc.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package jose
    18  
    19  import (
    20  	"crypto/ecdsa"
    21  	"crypto/rsa"
    22  	"errors"
    23  	"fmt"
    24  	"reflect"
    25  
    26  	"gopkg.in/square/go-jose.v2/json"
    27  )
    28  
// Encrypter represents an encrypter which produces an encrypted JWE object.
// NewEncrypter and NewMultiEncrypter in this file return implementations of
// this interface.
type Encrypter interface {
	// Encrypt encrypts plaintext into a new JWE object.
	Encrypt(plaintext []byte) (*JSONWebEncryption, error)
	// EncryptWithAuthData encrypts plaintext into a new JWE object and also
	// carries the caller-supplied additional authenticated data aad.
	EncryptWithAuthData(plaintext []byte, aad []byte) (*JSONWebEncryption, error)
	// Options reports the EncrypterOptions this encrypter was built with.
	Options() EncrypterOptions
}
    35  
// A generic content cipher: encrypts and decrypts the JWE payload under a
// content encryption key (CEK).
type contentCipher interface {
	// keySize returns the CEK length this cipher requires, in bytes
	// (NewEncrypter compares it against len() of a DIRECT []byte key).
	keySize() int
	// encrypt encrypts plaintext under cek. aad is the additional
	// authenticated data (EncryptWithAuthData passes computeAuthData()). The
	// returned parts supply the object's iv, ciphertext and tag.
	encrypt(cek []byte, aad, plaintext []byte) (*aeadParts, error)
	// decrypt reverses encrypt. Callers in this file map any error to
	// ErrCryptoFailure.
	decrypt(cek []byte, aad []byte, parts *aeadParts) ([]byte, error)
}

// A key generator (for generating/getting a CEK). Implementations used in
// this file are randomKeyGenerator (wrapped-key modes), staticKeyGenerator
// (DIRECT) and ecKeyGenerator (ECDH_ES).
type keyGenerator interface {
	// keySize reports the CEK size (presumably bytes, like
	// contentCipher.keySize — confirm in the implementations).
	keySize() int
	// genKey returns a CEK plus header fields that EncryptWithAuthData
	// merges into the protected header.
	genKey() ([]byte, rawHeader, error)
}

// A generic key encrypter: wraps a CEK for a single recipient.
type keyEncrypter interface {
	encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error) // Encrypt (wrap) cek with alg, producing the recipient's entry
}

// A generic key decrypter: recovers the CEK for a single recipient.
// NOTE(review): Decrypt/DecryptMulti pass a randomKeyGenerator sized for the
// content cipher as generator; presumably it is used to substitute a random
// CEK when unwrapping fails — confirm in the implementations.
type keyDecrypter interface {
	decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) // Decrypt (unwrap) the recipient's CEK
}
    58  
// A generic encrypter based on the given key encrypter and content cipher.
// It is the Encrypter returned by NewEncrypter and NewMultiEncrypter.
type genericEncrypter struct {
	// contentAlg is written to the protected header under headerEncryption.
	contentAlg     ContentEncryption
	// compressionAlg is applied to the plaintext unless it equals NONE.
	compressionAlg CompressionAlgorithm
	// cipher implements contentAlg (from getContentCipher).
	cipher         contentCipher
	// recipients holds one entry per recipient the CEK is wrapped for.
	recipients     []recipientKeyInfo
	// keyGenerator supplies the CEK for each encryption.
	keyGenerator   keyGenerator
	// extraHeaders are caller-supplied protected header entries. The
	// constructors store the EncrypterOptions map itself (no copy).
	extraHeaders   map[HeaderKey]interface{}
}

// recipientKeyInfo is the per-recipient encryption state: the optional key
// ID emitted under headerKeyID, the key management algorithm emitted under
// headerAlgorithm, and the keyEncrypter that wraps the CEK.
type recipientKeyInfo struct {
	keyID        string
	keyAlg       KeyAlgorithm
	keyEncrypter keyEncrypter
}
    74  
// EncrypterOptions represents options that can be set on new encrypters.
type EncrypterOptions struct {
	// Compression selects the algorithm used to compress the plaintext
	// before encryption; it is recorded in the protected "zip" header.
	// NONE disables compression.
	Compression CompressionAlgorithm

	// Optional map of additional keys to be inserted into the protected header
	// of a JWE object. Some specifications which make use of JWE like to insert
	// additional values here. All values must be JSON-serializable.
	// NOTE(review): these entries are written after the computed headers, so a
	// colliding key (e.g. the algorithm or encryption header) overrides them.
	ExtraHeaders map[HeaderKey]interface{}
}
    84  
    85  // WithHeader adds an arbitrary value to the ExtraHeaders map, initializing it
    86  // if necessary. It returns itself and so can be used in a fluent style.
    87  func (eo *EncrypterOptions) WithHeader(k HeaderKey, v interface{}) *EncrypterOptions {
    88  	if eo.ExtraHeaders == nil {
    89  		eo.ExtraHeaders = map[HeaderKey]interface{}{}
    90  	}
    91  	eo.ExtraHeaders[k] = v
    92  	return eo
    93  }
    94  
    95  // WithContentType adds a content type ("cty") header and returns the updated
    96  // EncrypterOptions.
    97  func (eo *EncrypterOptions) WithContentType(contentType ContentType) *EncrypterOptions {
    98  	return eo.WithHeader(HeaderContentType, contentType)
    99  }
   100  
   101  // WithType adds a type ("typ") header and returns the updated EncrypterOptions.
   102  func (eo *EncrypterOptions) WithType(typ ContentType) *EncrypterOptions {
   103  	return eo.WithHeader(HeaderType, typ)
   104  }
   105  
// Recipient represents an algorithm/key to encrypt messages to.
//
// Key holds the recipient's key material; which dynamic types are accepted
// is decided by NewEncrypter and makeJWERecipient. KeyID, when non-empty,
// takes precedence over any key ID derived from Key and is emitted under the
// key ID header.
//
// PBES2Count and PBES2Salt correspond with the "p2c" and "p2s" headers used
// on the password-based encryption algorithms PBES2-HS256+A128KW,
// PBES2-HS384+A192KW, and PBES2-HS512+A256KW. If they are not provided a safe
// default of 100000 will be used for the count and a 128-bit random salt will
// be generated. They are copied onto the key encrypter by addRecipient.
type Recipient struct {
	Algorithm  KeyAlgorithm
	Key        interface{}
	KeyID      string
	PBES2Count int
	PBES2Salt  []byte
}
   120  
   121  // NewEncrypter creates an appropriate encrypter based on the key type
   122  func NewEncrypter(enc ContentEncryption, rcpt Recipient, opts *EncrypterOptions) (Encrypter, error) {
   123  	encrypter := &genericEncrypter{
   124  		contentAlg: enc,
   125  		recipients: []recipientKeyInfo{},
   126  		cipher:     getContentCipher(enc),
   127  	}
   128  	if opts != nil {
   129  		encrypter.compressionAlg = opts.Compression
   130  		encrypter.extraHeaders = opts.ExtraHeaders
   131  	}
   132  
   133  	if encrypter.cipher == nil {
   134  		return nil, ErrUnsupportedAlgorithm
   135  	}
   136  
   137  	var keyID string
   138  	var rawKey interface{}
   139  	switch encryptionKey := rcpt.Key.(type) {
   140  	case JSONWebKey:
   141  		keyID, rawKey = encryptionKey.KeyID, encryptionKey.Key
   142  	case *JSONWebKey:
   143  		keyID, rawKey = encryptionKey.KeyID, encryptionKey.Key
   144  	case OpaqueKeyEncrypter:
   145  		keyID, rawKey = encryptionKey.KeyID(), encryptionKey
   146  	default:
   147  		rawKey = encryptionKey
   148  	}
   149  
   150  	switch rcpt.Algorithm {
   151  	case DIRECT:
   152  		// Direct encryption mode must be treated differently
   153  		if reflect.TypeOf(rawKey) != reflect.TypeOf([]byte{}) {
   154  			return nil, ErrUnsupportedKeyType
   155  		}
   156  		if encrypter.cipher.keySize() != len(rawKey.([]byte)) {
   157  			return nil, ErrInvalidKeySize
   158  		}
   159  		encrypter.keyGenerator = staticKeyGenerator{
   160  			key: rawKey.([]byte),
   161  		}
   162  		recipientInfo, _ := newSymmetricRecipient(rcpt.Algorithm, rawKey.([]byte))
   163  		recipientInfo.keyID = keyID
   164  		if rcpt.KeyID != "" {
   165  			recipientInfo.keyID = rcpt.KeyID
   166  		}
   167  		encrypter.recipients = []recipientKeyInfo{recipientInfo}
   168  		return encrypter, nil
   169  	case ECDH_ES:
   170  		// ECDH-ES (w/o key wrapping) is similar to DIRECT mode
   171  		typeOf := reflect.TypeOf(rawKey)
   172  		if typeOf != reflect.TypeOf(&ecdsa.PublicKey{}) {
   173  			return nil, ErrUnsupportedKeyType
   174  		}
   175  		encrypter.keyGenerator = ecKeyGenerator{
   176  			size:      encrypter.cipher.keySize(),
   177  			algID:     string(enc),
   178  			publicKey: rawKey.(*ecdsa.PublicKey),
   179  		}
   180  		recipientInfo, _ := newECDHRecipient(rcpt.Algorithm, rawKey.(*ecdsa.PublicKey))
   181  		recipientInfo.keyID = keyID
   182  		if rcpt.KeyID != "" {
   183  			recipientInfo.keyID = rcpt.KeyID
   184  		}
   185  		encrypter.recipients = []recipientKeyInfo{recipientInfo}
   186  		return encrypter, nil
   187  	default:
   188  		// Can just add a standard recipient
   189  		encrypter.keyGenerator = randomKeyGenerator{
   190  			size: encrypter.cipher.keySize(),
   191  		}
   192  		err := encrypter.addRecipient(rcpt)
   193  		return encrypter, err
   194  	}
   195  }
   196  
   197  // NewMultiEncrypter creates a multi-encrypter based on the given parameters
   198  func NewMultiEncrypter(enc ContentEncryption, rcpts []Recipient, opts *EncrypterOptions) (Encrypter, error) {
   199  	cipher := getContentCipher(enc)
   200  
   201  	if cipher == nil {
   202  		return nil, ErrUnsupportedAlgorithm
   203  	}
   204  	if rcpts == nil || len(rcpts) == 0 {
   205  		return nil, fmt.Errorf("square/go-jose: recipients is nil or empty")
   206  	}
   207  
   208  	encrypter := &genericEncrypter{
   209  		contentAlg: enc,
   210  		recipients: []recipientKeyInfo{},
   211  		cipher:     cipher,
   212  		keyGenerator: randomKeyGenerator{
   213  			size: cipher.keySize(),
   214  		},
   215  	}
   216  
   217  	if opts != nil {
   218  		encrypter.compressionAlg = opts.Compression
   219  		encrypter.extraHeaders = opts.ExtraHeaders
   220  	}
   221  
   222  	for _, recipient := range rcpts {
   223  		err := encrypter.addRecipient(recipient)
   224  		if err != nil {
   225  			return nil, err
   226  		}
   227  	}
   228  
   229  	return encrypter, nil
   230  }
   231  
   232  func (ctx *genericEncrypter) addRecipient(recipient Recipient) (err error) {
   233  	var recipientInfo recipientKeyInfo
   234  
   235  	switch recipient.Algorithm {
   236  	case DIRECT, ECDH_ES:
   237  		return fmt.Errorf("square/go-jose: key algorithm '%s' not supported in multi-recipient mode", recipient.Algorithm)
   238  	}
   239  
   240  	recipientInfo, err = makeJWERecipient(recipient.Algorithm, recipient.Key)
   241  	if recipient.KeyID != "" {
   242  		recipientInfo.keyID = recipient.KeyID
   243  	}
   244  
   245  	switch recipient.Algorithm {
   246  	case PBES2_HS256_A128KW, PBES2_HS384_A192KW, PBES2_HS512_A256KW:
   247  		if sr, ok := recipientInfo.keyEncrypter.(*symmetricKeyCipher); ok {
   248  			sr.p2c = recipient.PBES2Count
   249  			sr.p2s = recipient.PBES2Salt
   250  		}
   251  	}
   252  
   253  	if err == nil {
   254  		ctx.recipients = append(ctx.recipients, recipientInfo)
   255  	}
   256  	return err
   257  }
   258  
   259  func makeJWERecipient(alg KeyAlgorithm, encryptionKey interface{}) (recipientKeyInfo, error) {
   260  	switch encryptionKey := encryptionKey.(type) {
   261  	case *rsa.PublicKey:
   262  		return newRSARecipient(alg, encryptionKey)
   263  	case *ecdsa.PublicKey:
   264  		return newECDHRecipient(alg, encryptionKey)
   265  	case []byte:
   266  		return newSymmetricRecipient(alg, encryptionKey)
   267  	case string:
   268  		return newSymmetricRecipient(alg, []byte(encryptionKey))
   269  	case *JSONWebKey:
   270  		recipient, err := makeJWERecipient(alg, encryptionKey.Key)
   271  		recipient.keyID = encryptionKey.KeyID
   272  		return recipient, err
   273  	}
   274  	if encrypter, ok := encryptionKey.(OpaqueKeyEncrypter); ok {
   275  		return newOpaqueKeyEncrypter(alg, encrypter)
   276  	}
   277  	return recipientKeyInfo{}, ErrUnsupportedKeyType
   278  }
   279  
   280  // newDecrypter creates an appropriate decrypter based on the key type
   281  func newDecrypter(decryptionKey interface{}) (keyDecrypter, error) {
   282  	switch decryptionKey := decryptionKey.(type) {
   283  	case *rsa.PrivateKey:
   284  		return &rsaDecrypterSigner{
   285  			privateKey: decryptionKey,
   286  		}, nil
   287  	case *ecdsa.PrivateKey:
   288  		return &ecDecrypterSigner{
   289  			privateKey: decryptionKey,
   290  		}, nil
   291  	case []byte:
   292  		return &symmetricKeyCipher{
   293  			key: decryptionKey,
   294  		}, nil
   295  	case string:
   296  		return &symmetricKeyCipher{
   297  			key: []byte(decryptionKey),
   298  		}, nil
   299  	case JSONWebKey:
   300  		return newDecrypter(decryptionKey.Key)
   301  	case *JSONWebKey:
   302  		return newDecrypter(decryptionKey.Key)
   303  	}
   304  	if okd, ok := decryptionKey.(OpaqueKeyDecrypter); ok {
   305  		return &opaqueKeyDecrypter{decrypter: okd}, nil
   306  	}
   307  	return nil, ErrUnsupportedKeyType
   308  }
   309  
   310  // Implementation of encrypt method producing a JWE object.
   311  func (ctx *genericEncrypter) Encrypt(plaintext []byte) (*JSONWebEncryption, error) {
   312  	return ctx.EncryptWithAuthData(plaintext, nil)
   313  }
   314  
// EncryptWithAuthData encrypts plaintext into a new JWE object. aad is stored
// on the object (obj.aad) before obj.computeAuthData() is called, so it
// presumably becomes part of the authenticated data — confirm in
// computeAuthData.
//
// The protected header is assembled in this order, all before computeAuthData
// runs:
//  1. headerEncryption;
//  2. any fields returned by the key generator;
//  3. with exactly one recipient, that recipient's headerAlgorithm/headerKeyID;
//  4. headerCompression when compression is enabled;
//  5. finally ExtraHeaders.
//
// NOTE(review): because ExtraHeaders are written last, an entry whose key
// collides with a header set above silently overwrites it.
func (ctx *genericEncrypter) EncryptWithAuthData(plaintext, aad []byte) (*JSONWebEncryption, error) {
	obj := &JSONWebEncryption{}
	obj.aad = aad

	obj.protected = &rawHeader{}
	err := obj.protected.set(headerEncryption, ctx.contentAlg)
	if err != nil {
		return nil, err
	}

	obj.recipients = make([]recipientInfo, len(ctx.recipients))

	if len(ctx.recipients) == 0 {
		return nil, fmt.Errorf("square/go-jose: no recipients to encrypt to")
	}

	// One CEK is generated and shared by all recipients; each wraps it below.
	cek, headers, err := ctx.keyGenerator.genKey()
	if err != nil {
		return nil, err
	}

	obj.protected.merge(&headers)

	for i, info := range ctx.recipients {
		// Wrap the CEK for this recipient.
		recipient, err := info.keyEncrypter.encryptKey(cek, info.keyAlg)
		if err != nil {
			return nil, err
		}

		// Per-recipient header: key management algorithm and optional key ID.
		err = recipient.header.set(headerAlgorithm, info.keyAlg)
		if err != nil {
			return nil, err
		}

		if info.keyID != "" {
			err = recipient.header.set(headerKeyID, info.keyID)
			if err != nil {
				return nil, err
			}
		}
		obj.recipients[i] = recipient
	}

	if len(ctx.recipients) == 1 {
		// Move per-recipient headers into main protected header if there's
		// only a single recipient.
		obj.protected.merge(obj.recipients[0].header)
		obj.recipients[0].header = nil
	}

	// Compression happens before encryption and is advertised in the
	// protected header.
	if ctx.compressionAlg != NONE {
		plaintext, err = compress(ctx.compressionAlg, plaintext)
		if err != nil {
			return nil, err
		}

		err = obj.protected.set(headerCompression, ctx.compressionAlg)
		if err != nil {
			return nil, err
		}
	}

	// Caller-supplied headers are JSON-encoded and may overwrite entries above.
	for k, v := range ctx.extraHeaders {
		b, err := json.Marshal(v)
		if err != nil {
			return nil, err
		}
		(*obj.protected)[k] = makeRawMessage(b)
	}

	// The header must be final at this point: the auth data is computed once.
	authData := obj.computeAuthData()
	parts, err := ctx.cipher.encrypt(cek, authData, plaintext)
	if err != nil {
		return nil, err
	}

	obj.iv = parts.iv
	obj.ciphertext = parts.ciphertext
	obj.tag = parts.tag

	return obj, nil
}
   398  
   399  func (ctx *genericEncrypter) Options() EncrypterOptions {
   400  	return EncrypterOptions{
   401  		Compression:  ctx.compressionAlg,
   402  		ExtraHeaders: ctx.extraHeaders,
   403  	}
   404  }
   405  
   406  // Decrypt and validate the object and return the plaintext. Note that this
   407  // function does not support multi-recipient, if you desire multi-recipient
   408  // decryption use DecryptMulti instead.
   409  func (obj JSONWebEncryption) Decrypt(decryptionKey interface{}) ([]byte, error) {
   410  	headers := obj.mergedHeaders(nil)
   411  
   412  	if len(obj.recipients) > 1 {
   413  		return nil, errors.New("square/go-jose: too many recipients in payload; expecting only one")
   414  	}
   415  
   416  	critical, err := headers.getCritical()
   417  	if err != nil {
   418  		return nil, fmt.Errorf("square/go-jose: invalid crit header")
   419  	}
   420  
   421  	if len(critical) > 0 {
   422  		return nil, fmt.Errorf("square/go-jose: unsupported crit header")
   423  	}
   424  
   425  	decrypter, err := newDecrypter(decryptionKey)
   426  	if err != nil {
   427  		return nil, err
   428  	}
   429  
   430  	cipher := getContentCipher(headers.getEncryption())
   431  	if cipher == nil {
   432  		return nil, fmt.Errorf("square/go-jose: unsupported enc value '%s'", string(headers.getEncryption()))
   433  	}
   434  
   435  	generator := randomKeyGenerator{
   436  		size: cipher.keySize(),
   437  	}
   438  
   439  	parts := &aeadParts{
   440  		iv:         obj.iv,
   441  		ciphertext: obj.ciphertext,
   442  		tag:        obj.tag,
   443  	}
   444  
   445  	authData := obj.computeAuthData()
   446  
   447  	var plaintext []byte
   448  	recipient := obj.recipients[0]
   449  	recipientHeaders := obj.mergedHeaders(&recipient)
   450  
   451  	cek, err := decrypter.decryptKey(recipientHeaders, &recipient, generator)
   452  	if err == nil {
   453  		// Found a valid CEK -- let's try to decrypt.
   454  		plaintext, err = cipher.decrypt(cek, authData, parts)
   455  	}
   456  
   457  	if plaintext == nil {
   458  		return nil, ErrCryptoFailure
   459  	}
   460  
   461  	// The "zip" header parameter may only be present in the protected header.
   462  	if comp := obj.protected.getCompression(); comp != "" {
   463  		plaintext, err = decompress(comp, plaintext)
   464  	}
   465  
   466  	return plaintext, err
   467  }
   468  
   469  // DecryptMulti decrypts and validates the object and returns the plaintexts,
   470  // with support for multiple recipients. It returns the index of the recipient
   471  // for which the decryption was successful, the merged headers for that recipient,
   472  // and the plaintext.
   473  func (obj JSONWebEncryption) DecryptMulti(decryptionKey interface{}) (int, Header, []byte, error) {
   474  	globalHeaders := obj.mergedHeaders(nil)
   475  
   476  	critical, err := globalHeaders.getCritical()
   477  	if err != nil {
   478  		return -1, Header{}, nil, fmt.Errorf("square/go-jose: invalid crit header")
   479  	}
   480  
   481  	if len(critical) > 0 {
   482  		return -1, Header{}, nil, fmt.Errorf("square/go-jose: unsupported crit header")
   483  	}
   484  
   485  	decrypter, err := newDecrypter(decryptionKey)
   486  	if err != nil {
   487  		return -1, Header{}, nil, err
   488  	}
   489  
   490  	encryption := globalHeaders.getEncryption()
   491  	cipher := getContentCipher(encryption)
   492  	if cipher == nil {
   493  		return -1, Header{}, nil, fmt.Errorf("square/go-jose: unsupported enc value '%s'", string(encryption))
   494  	}
   495  
   496  	generator := randomKeyGenerator{
   497  		size: cipher.keySize(),
   498  	}
   499  
   500  	parts := &aeadParts{
   501  		iv:         obj.iv,
   502  		ciphertext: obj.ciphertext,
   503  		tag:        obj.tag,
   504  	}
   505  
   506  	authData := obj.computeAuthData()
   507  
   508  	index := -1
   509  	var plaintext []byte
   510  	var headers rawHeader
   511  
   512  	for i, recipient := range obj.recipients {
   513  		recipientHeaders := obj.mergedHeaders(&recipient)
   514  
   515  		cek, err := decrypter.decryptKey(recipientHeaders, &recipient, generator)
   516  		if err == nil {
   517  			// Found a valid CEK -- let's try to decrypt.
   518  			plaintext, err = cipher.decrypt(cek, authData, parts)
   519  			if err == nil {
   520  				index = i
   521  				headers = recipientHeaders
   522  				break
   523  			}
   524  		}
   525  	}
   526  
   527  	if plaintext == nil || err != nil {
   528  		return -1, Header{}, nil, ErrCryptoFailure
   529  	}
   530  
   531  	// The "zip" header parameter may only be present in the protected header.
   532  	if comp := obj.protected.getCompression(); comp != "" {
   533  		plaintext, err = decompress(comp, plaintext)
   534  	}
   535  
   536  	sanitized, err := headers.sanitized()
   537  	if err != nil {
   538  		return -1, Header{}, nil, fmt.Errorf("square/go-jose: failed to sanitize header: %v", err)
   539  	}
   540  
   541  	return index, sanitized, plaintext, err
   542  }
   543  

View as plain text