...

Source file src/github.com/lestrrat-go/jwx/jwe/encrypt.go

Documentation: github.com/lestrrat-go/jwx/jwe

     1  package jwe
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  
     7  	"github.com/lestrrat-go/jwx/internal/base64"
     8  	"github.com/lestrrat-go/jwx/jwa"
     9  	"github.com/pkg/errors"
    10  )
    11  
    12  var encryptCtxPool = sync.Pool{
    13  	New: func() interface{} {
    14  		return &encryptCtx{}
    15  	},
    16  }
    17  
    18  func getEncryptCtx() *encryptCtx {
    19  	//nolint:forcetypeassert
    20  	return encryptCtxPool.Get().(*encryptCtx)
    21  }
    22  
    23  func releaseEncryptCtx(ctx *encryptCtx) {
    24  	ctx.protected = nil
    25  	ctx.contentEncrypter = nil
    26  	ctx.generator = nil
    27  	ctx.keyEncrypters = nil
    28  	ctx.compress = jwa.NoCompress
    29  	encryptCtxPool.Put(ctx)
    30  }
    31  
    32  // Encrypt takes the plaintext and encrypts into a JWE message.
    33  func (e encryptCtx) Encrypt(plaintext []byte) (*Message, error) {
    34  	bk, err := e.generator.Generate()
    35  	if err != nil {
    36  		return nil, errors.Wrap(err, "failed to generate key")
    37  	}
    38  	cek := bk.Bytes()
    39  
    40  	if e.protected == nil {
    41  		// shouldn't happen, but...
    42  		e.protected = NewHeaders()
    43  	}
    44  
    45  	if err := e.protected.Set(ContentEncryptionKey, e.contentEncrypter.Algorithm()); err != nil {
    46  		return nil, errors.Wrap(err, `failed to set "enc" in protected header`)
    47  	}
    48  
    49  	compression := e.compress
    50  	if compression != jwa.NoCompress {
    51  		if err := e.protected.Set(CompressionKey, compression); err != nil {
    52  			return nil, errors.Wrap(err, `failed to set "zip" in protected header`)
    53  		}
    54  	}
    55  
    56  	// In JWE, multiple recipients may exist -- they receive an
    57  	// encrypted version of the CEK, using their key encryption
    58  	// algorithm of choice.
    59  	recipients := make([]Recipient, len(e.keyEncrypters))
    60  	for i, enc := range e.keyEncrypters {
    61  		r := NewRecipient()
    62  		if err := r.Headers().Set(AlgorithmKey, enc.Algorithm()); err != nil {
    63  			return nil, errors.Wrap(err, "failed to set header")
    64  		}
    65  		if v := enc.KeyID(); v != "" {
    66  			if err := r.Headers().Set(KeyIDKey, v); err != nil {
    67  				return nil, errors.Wrap(err, "failed to set header")
    68  			}
    69  		}
    70  
    71  		enckey, err := enc.Encrypt(cek)
    72  		if err != nil {
    73  			return nil, errors.Wrap(err, `failed to encrypt key`)
    74  		}
    75  		if enc.Algorithm() == jwa.ECDH_ES || enc.Algorithm() == jwa.DIRECT {
    76  			if len(e.keyEncrypters) > 1 {
    77  				return nil, errors.Errorf("unable to support multiple recipients for ECDH-ES")
    78  			}
    79  			cek = enckey.Bytes()
    80  		} else {
    81  			if err := r.SetEncryptedKey(enckey.Bytes()); err != nil {
    82  				return nil, errors.Wrap(err, "failed to set encrypted key")
    83  			}
    84  		}
    85  		if hp, ok := enckey.(populater); ok {
    86  			if err := hp.Populate(r.Headers()); err != nil {
    87  				return nil, errors.Wrap(err, "failed to populate")
    88  			}
    89  		}
    90  		recipients[i] = r
    91  	}
    92  
    93  	// If there's only one recipient, you want to include that in the
    94  	// protected header
    95  	if len(recipients) == 1 {
    96  		h, err := e.protected.Merge(context.TODO(), recipients[0].Headers())
    97  		if err != nil {
    98  			return nil, errors.Wrap(err, "failed to merge protected headers")
    99  		}
   100  		e.protected = h
   101  	}
   102  
   103  	aad, err := e.protected.Encode()
   104  	if err != nil {
   105  		return nil, errors.Wrap(err, "failed to base64 encode protected headers")
   106  	}
   107  
   108  	plaintext, err = compress(plaintext, compression)
   109  	if err != nil {
   110  		return nil, errors.Wrap(err, `failed to compress payload before encryption`)
   111  	}
   112  
   113  	// ...on the other hand, there's only one content cipher.
   114  	iv, ciphertext, tag, err := e.contentEncrypter.Encrypt(cek, plaintext, aad)
   115  	if err != nil {
   116  		return nil, errors.Wrap(err, "failed to encrypt payload")
   117  	}
   118  
   119  	msg := NewMessage()
   120  
   121  	decodedAad, err := base64.Decode(aad)
   122  	if err != nil {
   123  		return nil, errors.Wrap(err, "failed to decode base64")
   124  	}
   125  	if err := msg.Set(AuthenticatedDataKey, decodedAad); err != nil {
   126  		return nil, errors.Wrapf(err, `failed to set %s`, AuthenticatedDataKey)
   127  	}
   128  	if err := msg.Set(CipherTextKey, ciphertext); err != nil {
   129  		return nil, errors.Wrapf(err, `failed to set %s`, CipherTextKey)
   130  	}
   131  	if err := msg.Set(InitializationVectorKey, iv); err != nil {
   132  		return nil, errors.Wrapf(err, `failed to set %s`, InitializationVectorKey)
   133  	}
   134  	if err := msg.Set(ProtectedHeadersKey, e.protected); err != nil {
   135  		return nil, errors.Wrapf(err, `failed to set %s`, ProtectedHeadersKey)
   136  	}
   137  	if err := msg.Set(RecipientsKey, recipients); err != nil {
   138  		return nil, errors.Wrapf(err, `failed to set %s`, RecipientsKey)
   139  	}
   140  	if err := msg.Set(TagKey, tag); err != nil {
   141  		return nil, errors.Wrapf(err, `failed to set %s`, TagKey)
   142  	}
   143  
   144  	return msg, nil
   145  }
   146  

View as plain text