...

Source file src/github.com/lestrrat-go/jwx/jwe/jwe.go

Documentation: github.com/lestrrat-go/jwx/jwe

     1  //go:generate ./gen.sh
     2  
     3  // Package jwe implements JWE as described in https://tools.ietf.org/html/rfc7516
     4  package jwe
     5  
     6  import (
     7  	"bytes"
     8  	"crypto/ecdsa"
     9  	"crypto/rsa"
    10  	"io"
    11  	"io/ioutil"
    12  
    13  	"github.com/lestrrat-go/jwx/internal/base64"
    14  	"github.com/lestrrat-go/jwx/internal/json"
    15  	"github.com/lestrrat-go/jwx/internal/keyconv"
    16  	"github.com/lestrrat-go/jwx/jwk"
    17  
    18  	"github.com/lestrrat-go/jwx/jwa"
    19  	"github.com/lestrrat-go/jwx/jwe/internal/content_crypt"
    20  	"github.com/lestrrat-go/jwx/jwe/internal/keyenc"
    21  	"github.com/lestrrat-go/jwx/jwe/internal/keygen"
    22  	"github.com/lestrrat-go/jwx/x25519"
    23  	"github.com/pkg/errors"
    24  )
    25  
// registry is the package-wide registry that RegisterCustomField adds
// custom field definitions to.
var registry = json.NewRegistry()
    27  
    28  // Encrypt takes the plaintext payload and encrypts it in JWE compact format.
    29  // `key` should be a public key, and it may be a raw key (e.g. rsa.PublicKey) or a jwk.Key
    30  //
    31  // Encrypt currently does not support multi-recipient messages.
    32  func Encrypt(payload []byte, keyalg jwa.KeyEncryptionAlgorithm, key interface{}, contentalg jwa.ContentEncryptionAlgorithm, compressalg jwa.CompressionAlgorithm, options ...EncryptOption) ([]byte, error) {
    33  	var protected Headers
    34  	for _, option := range options {
    35  		//nolint:forcetypeassert
    36  		switch option.Ident() {
    37  		case identProtectedHeader{}:
    38  			protected = option.Value().(Headers)
    39  		}
    40  	}
    41  	if protected == nil {
    42  		protected = NewHeaders()
    43  	}
    44  
    45  	contentcrypt, err := content_crypt.NewGeneric(contentalg)
    46  	if err != nil {
    47  		return nil, errors.Wrap(err, `failed to create AES encrypter`)
    48  	}
    49  
    50  	var keyID string
    51  	if jwkKey, ok := key.(jwk.Key); ok {
    52  		keyID = jwkKey.KeyID()
    53  
    54  		var raw interface{}
    55  		if err := jwkKey.Raw(&raw); err != nil {
    56  			return nil, errors.Wrapf(err, `failed to retrieve raw key out of %T`, key)
    57  		}
    58  
    59  		key = raw
    60  	}
    61  
    62  	var enc keyenc.Encrypter
    63  	switch keyalg {
    64  	case jwa.RSA1_5:
    65  		var pubkey rsa.PublicKey
    66  		if err := keyconv.RSAPublicKey(&pubkey, key); err != nil {
    67  			return nil, errors.Wrapf(err, "failed to generate public key from key (%T)", key)
    68  		}
    69  
    70  		enc, err = keyenc.NewRSAPKCSEncrypt(keyalg, &pubkey)
    71  		if err != nil {
    72  			return nil, errors.Wrap(err, "failed to create RSA PKCS encrypter")
    73  		}
    74  	case jwa.RSA_OAEP, jwa.RSA_OAEP_256:
    75  		var pubkey rsa.PublicKey
    76  		if err := keyconv.RSAPublicKey(&pubkey, key); err != nil {
    77  			return nil, errors.Wrapf(err, "failed to generate public key from key (%T)", key)
    78  		}
    79  
    80  		enc, err = keyenc.NewRSAOAEPEncrypt(keyalg, &pubkey)
    81  		if err != nil {
    82  			return nil, errors.Wrap(err, "failed to create RSA OAEP encrypter")
    83  		}
    84  	case jwa.A128KW, jwa.A192KW, jwa.A256KW,
    85  		jwa.A128GCMKW, jwa.A192GCMKW, jwa.A256GCMKW,
    86  		jwa.PBES2_HS256_A128KW, jwa.PBES2_HS384_A192KW, jwa.PBES2_HS512_A256KW:
    87  		sharedkey, ok := key.([]byte)
    88  		if !ok {
    89  			return nil, errors.New("invalid key: []byte required")
    90  		}
    91  		switch keyalg {
    92  		case jwa.A128KW, jwa.A192KW, jwa.A256KW:
    93  			enc, err = keyenc.NewAES(keyalg, sharedkey)
    94  		case jwa.PBES2_HS256_A128KW, jwa.PBES2_HS384_A192KW, jwa.PBES2_HS512_A256KW:
    95  			enc, err = keyenc.NewPBES2Encrypt(keyalg, sharedkey)
    96  		default:
    97  			enc, err = keyenc.NewAESGCMEncrypt(keyalg, sharedkey)
    98  		}
    99  		if err != nil {
   100  			return nil, errors.Wrap(err, "failed to create key wrap encrypter")
   101  		}
   102  		// NOTE: there was formerly a restriction, introduced
   103  		// in PR #26, which disallowed certain key/content
   104  		// algorithm combinations. This seemed bogus, and
   105  		// interop with the jose tool demonstrates it.
   106  	case jwa.ECDH_ES, jwa.ECDH_ES_A128KW, jwa.ECDH_ES_A192KW, jwa.ECDH_ES_A256KW:
   107  		var keysize int
   108  		switch keyalg {
   109  		case jwa.ECDH_ES:
   110  			// https://tools.ietf.org/html/rfc7518#page-15
   111  			// In Direct Key Agreement mode, the output of the Concat KDF MUST be a
   112  			// key of the same length as that used by the "enc" algorithm.
   113  			keysize = contentcrypt.KeySize()
   114  		case jwa.ECDH_ES_A128KW:
   115  			keysize = 16
   116  		case jwa.ECDH_ES_A192KW:
   117  			keysize = 24
   118  		case jwa.ECDH_ES_A256KW:
   119  			keysize = 32
   120  		}
   121  
   122  		switch key := key.(type) {
   123  		case x25519.PublicKey:
   124  			enc, err = keyenc.NewECDHESEncrypt(keyalg, contentalg, keysize, key)
   125  		default:
   126  			var pubkey ecdsa.PublicKey
   127  			if err := keyconv.ECDSAPublicKey(&pubkey, key); err != nil {
   128  				return nil, errors.Wrapf(err, "failed to generate public key from key (%T)", key)
   129  			}
   130  			enc, err = keyenc.NewECDHESEncrypt(keyalg, contentalg, keysize, &pubkey)
   131  		}
   132  		if err != nil {
   133  			return nil, errors.Wrap(err, "failed to create ECDHS key wrap encrypter")
   134  		}
   135  	case jwa.DIRECT:
   136  		sharedkey, ok := key.([]byte)
   137  		if !ok {
   138  			return nil, errors.New("invalid key: []byte required")
   139  		}
   140  		enc, _ = keyenc.NewNoop(keyalg, sharedkey)
   141  	default:
   142  		return nil, errors.Errorf(`invalid key encryption algorithm (%s)`, keyalg)
   143  	}
   144  
   145  	if keyID != "" {
   146  		enc.SetKeyID(keyID)
   147  	}
   148  
   149  	keysize := contentcrypt.KeySize()
   150  	encctx := getEncryptCtx()
   151  	defer releaseEncryptCtx(encctx)
   152  
   153  	encctx.protected = protected
   154  	encctx.contentEncrypter = contentcrypt
   155  	encctx.generator = keygen.NewRandom(keysize)
   156  	encctx.keyEncrypters = []keyenc.Encrypter{enc}
   157  	encctx.compress = compressalg
   158  	msg, err := encctx.Encrypt(payload)
   159  	if err != nil {
   160  		return nil, errors.Wrap(err, "failed to encrypt payload")
   161  	}
   162  
   163  	return Compact(msg)
   164  }
   165  
   166  // DecryptCtx is used internally when jwe.Decrypt is called, and is
   167  // passed for hooks that you may pass into it.
   168  //
   169  // Regular users should not have to touch this object, but if you need advanced handling
   170  // of messages, you might have to use it. Only use it when you really
   171  // understand how JWE processing works in this library.
   172  type DecryptCtx interface {
   173  	Algorithm() jwa.KeyEncryptionAlgorithm
   174  	SetAlgorithm(jwa.KeyEncryptionAlgorithm)
   175  	Key() interface{}
   176  	SetKey(interface{})
   177  	Message() *Message
   178  	SetMessage(*Message)
   179  }
   180  
   181  type decryptCtx struct {
   182  	alg jwa.KeyEncryptionAlgorithm
   183  	key interface{}
   184  	msg *Message
   185  }
   186  
   187  func (ctx *decryptCtx) Algorithm() jwa.KeyEncryptionAlgorithm {
   188  	return ctx.alg
   189  }
   190  
   191  func (ctx *decryptCtx) SetAlgorithm(v jwa.KeyEncryptionAlgorithm) {
   192  	ctx.alg = v
   193  }
   194  
   195  func (ctx *decryptCtx) Key() interface{} {
   196  	return ctx.key
   197  }
   198  
   199  func (ctx *decryptCtx) SetKey(v interface{}) {
   200  	ctx.key = v
   201  }
   202  
   203  func (ctx *decryptCtx) Message() *Message {
   204  	return ctx.msg
   205  }
   206  
   207  func (ctx *decryptCtx) SetMessage(m *Message) {
   208  	ctx.msg = m
   209  }
   210  
   211  // Decrypt takes the key encryption algorithm and the corresponding
   212  // key to decrypt the JWE message, and returns the decrypted payload.
   213  // The JWE message can be either compact or full JSON format.
   214  //
   215  // `key` must be a private key. It can be either in its raw format (e.g. *rsa.PrivateKey) or a jwk.Key
   216  func Decrypt(buf []byte, alg jwa.KeyEncryptionAlgorithm, key interface{}, options ...DecryptOption) ([]byte, error) {
   217  	var ctx decryptCtx
   218  	ctx.key = key
   219  	ctx.alg = alg
   220  
   221  	var dst *Message
   222  	var postParse PostParser
   223  	//nolint:forcetypeassert
   224  	for _, option := range options {
   225  		switch option.Ident() {
   226  		case identMessage{}:
   227  			dst = option.Value().(*Message)
   228  		case identPostParser{}:
   229  			postParse = option.Value().(PostParser)
   230  		}
   231  	}
   232  
   233  	msg, err := parseJSONOrCompact(buf, true)
   234  	if err != nil {
   235  		return nil, errors.Wrap(err, "failed to parse buffer for Decrypt")
   236  	}
   237  
   238  	ctx.msg = msg
   239  	if postParse != nil {
   240  		if err := postParse.PostParse(&ctx); err != nil {
   241  			return nil, errors.Wrap(err, `failed to execute PostParser hook`)
   242  		}
   243  	}
   244  
   245  	payload, err := doDecryptCtx(&ctx)
   246  	if err != nil {
   247  		return nil, errors.Wrap(err, `failed to decrypt message`)
   248  	}
   249  
   250  	if dst != nil {
   251  		*dst = *msg
   252  		dst.rawProtectedHeaders = nil
   253  		dst.storeProtectedHeaders = false
   254  	}
   255  
   256  	return payload, nil
   257  }
   258  
   259  // Parse parses the JWE message into a Message object. The JWE message
   260  // can be either compact or full JSON format.
   261  func Parse(buf []byte) (*Message, error) {
   262  	return parseJSONOrCompact(buf, false)
   263  }
   264  
   265  func parseJSONOrCompact(buf []byte, storeProtectedHeaders bool) (*Message, error) {
   266  	buf = bytes.TrimSpace(buf)
   267  	if len(buf) == 0 {
   268  		return nil, errors.New("empty buffer")
   269  	}
   270  
   271  	if buf[0] == '{' {
   272  		return parseJSON(buf, storeProtectedHeaders)
   273  	}
   274  	return parseCompact(buf, storeProtectedHeaders)
   275  }
   276  
   277  // ParseString is the same as Parse, but takes a string.
   278  func ParseString(s string) (*Message, error) {
   279  	return Parse([]byte(s))
   280  }
   281  
   282  // ParseReader is the same as Parse, but takes an io.Reader.
   283  func ParseReader(src io.Reader) (*Message, error) {
   284  	buf, err := ioutil.ReadAll(src)
   285  	if err != nil {
   286  		return nil, errors.Wrap(err, `failed to read from io.Reader`)
   287  	}
   288  	return Parse(buf)
   289  }
   290  
   291  func parseJSON(buf []byte, storeProtectedHeaders bool) (*Message, error) {
   292  	m := NewMessage()
   293  	m.storeProtectedHeaders = storeProtectedHeaders
   294  	if err := json.Unmarshal(buf, &m); err != nil {
   295  		return nil, errors.Wrap(err, "failed to parse JSON")
   296  	}
   297  	return m, nil
   298  }
   299  
   300  func parseCompact(buf []byte, storeProtectedHeaders bool) (*Message, error) {
   301  	parts := bytes.Split(buf, []byte{'.'})
   302  	if len(parts) != 5 {
   303  		return nil, errors.Errorf(`compact JWE format must have five parts (%d)`, len(parts))
   304  	}
   305  
   306  	hdrbuf, err := base64.Decode(parts[0])
   307  	if err != nil {
   308  		return nil, errors.Wrap(err, `failed to parse first part of compact form`)
   309  	}
   310  
   311  	protected := NewHeaders()
   312  	if err := json.Unmarshal(hdrbuf, protected); err != nil {
   313  		return nil, errors.Wrap(err, "failed to parse header JSON")
   314  	}
   315  
   316  	ivbuf, err := base64.Decode(parts[2])
   317  	if err != nil {
   318  		return nil, errors.Wrap(err, "failed to base64 decode iv")
   319  	}
   320  
   321  	ctbuf, err := base64.Decode(parts[3])
   322  	if err != nil {
   323  		return nil, errors.Wrap(err, "failed to base64 decode content")
   324  	}
   325  
   326  	tagbuf, err := base64.Decode(parts[4])
   327  	if err != nil {
   328  		return nil, errors.Wrap(err, "failed to base64 decode tag")
   329  	}
   330  
   331  	m := NewMessage()
   332  	if err := m.Set(CipherTextKey, ctbuf); err != nil {
   333  		return nil, errors.Wrapf(err, `failed to set %s`, CipherTextKey)
   334  	}
   335  	if err := m.Set(InitializationVectorKey, ivbuf); err != nil {
   336  		return nil, errors.Wrapf(err, `failed to set %s`, InitializationVectorKey)
   337  	}
   338  	if err := m.Set(ProtectedHeadersKey, protected); err != nil {
   339  		return nil, errors.Wrapf(err, `failed to set %s`, ProtectedHeadersKey)
   340  	}
   341  
   342  	if err := m.makeDummyRecipient(string(parts[1]), protected); err != nil {
   343  		return nil, errors.Wrap(err, `failed to setup recipient`)
   344  	}
   345  
   346  	if err := m.Set(TagKey, tagbuf); err != nil {
   347  		return nil, errors.Wrapf(err, `failed to set %s`, TagKey)
   348  	}
   349  
   350  	if storeProtectedHeaders {
   351  		// This is later used for decryption.
   352  		m.rawProtectedHeaders = parts[0]
   353  	}
   354  
   355  	return m, nil
   356  }
   357  
   358  // RegisterCustomField allows users to specify that a private field
   359  // be decoded as an instance of the specified type. This option has
   360  // a global effect.
   361  //
   362  // For example, suppose you have a custom field `x-birthday`, which
   363  // you want to represent as a string formatted in RFC3339 in JSON,
   364  // but want it back as `time.Time`.
   365  //
   366  // In that case you would register a custom field as follows
   367  //
   368  //   jwe.RegisterCustomField(`x-birthday`, timeT)
   369  //
   370  // Then `hdr.Get("x-birthday")` will still return an `interface{}`,
   371  // but you can convert its type to `time.Time`
   372  //
   373  //   bdayif, _ := hdr.Get(`x-birthday`)
   374  //   bday := bdayif.(time.Time)
   375  func RegisterCustomField(name string, object interface{}) {
   376  	registry.Register(name, object)
   377  }
   378  

View as plain text