...

Source file src/github.com/lestrrat-go/jwx/jwe/internal/cipher/cipher.go

Documentation: github.com/lestrrat-go/jwx/jwe/internal/cipher

     1  package cipher
     2  
     3  import (
     4  	"crypto/aes"
     5  	"crypto/cipher"
     6  	"fmt"
     7  
     8  	"github.com/lestrrat-go/jwx/jwa"
     9  	"github.com/lestrrat-go/jwx/jwe/internal/aescbc"
    10  	"github.com/lestrrat-go/jwx/jwe/internal/keygen"
    11  	"github.com/pkg/errors"
    12  )
    13  
    14  var gcm = &gcmFetcher{}
    15  var cbc = &cbcFetcher{}
    16  
    17  func (f gcmFetcher) Fetch(key []byte) (cipher.AEAD, error) {
    18  	aescipher, err := aes.NewCipher(key)
    19  	if err != nil {
    20  		return nil, errors.Wrap(err, "cipher: failed to create AES cipher for GCM")
    21  	}
    22  
    23  	aead, err := cipher.NewGCM(aescipher)
    24  	if err != nil {
    25  		return nil, errors.Wrap(err, `failed to create GCM for cipher`)
    26  	}
    27  	return aead, nil
    28  }
    29  
    30  func (f cbcFetcher) Fetch(key []byte) (cipher.AEAD, error) {
    31  	aead, err := aescbc.New(key, aes.NewCipher)
    32  	if err != nil {
    33  		return nil, errors.Wrap(err, "cipher: failed to create AES cipher for CBC")
    34  	}
    35  	return aead, nil
    36  }
    37  
    38  func (c AesContentCipher) KeySize() int {
    39  	return c.keysize
    40  }
    41  
    42  func (c AesContentCipher) TagSize() int {
    43  	return c.tagsize
    44  }
    45  
    46  func NewAES(alg jwa.ContentEncryptionAlgorithm) (*AesContentCipher, error) {
    47  	var keysize int
    48  	var tagsize int
    49  	var fetcher Fetcher
    50  	switch alg {
    51  	case jwa.A128GCM:
    52  		keysize = 16
    53  		tagsize = 16
    54  		fetcher = gcm
    55  	case jwa.A192GCM:
    56  		keysize = 24
    57  		tagsize = 16
    58  		fetcher = gcm
    59  	case jwa.A256GCM:
    60  		keysize = 32
    61  		tagsize = 16
    62  		fetcher = gcm
    63  	case jwa.A128CBC_HS256:
    64  		tagsize = 16
    65  		keysize = tagsize * 2
    66  		fetcher = cbc
    67  	case jwa.A192CBC_HS384:
    68  		tagsize = 24
    69  		keysize = tagsize * 2
    70  		fetcher = cbc
    71  	case jwa.A256CBC_HS512:
    72  		tagsize = 32
    73  		keysize = tagsize * 2
    74  		fetcher = cbc
    75  	default:
    76  		return nil, errors.Errorf("failed to create AES content cipher: invalid algorithm (%s)", alg)
    77  	}
    78  
    79  	return &AesContentCipher{
    80  		keysize: keysize,
    81  		tagsize: tagsize,
    82  		fetch:   fetcher,
    83  	}, nil
    84  }
    85  
    86  func (c AesContentCipher) Encrypt(cek, plaintext, aad []byte) (iv, ciphertext, tag []byte, err error) {
    87  	var aead cipher.AEAD
    88  	aead, err = c.fetch.Fetch(cek)
    89  	if err != nil {
    90  		return nil, nil, nil, errors.Wrap(err, "failed to fetch AEAD")
    91  	}
    92  
    93  	// Seal may panic (argh!), so protect ourselves from that
    94  	defer func() {
    95  		if e := recover(); e != nil {
    96  			switch e := e.(type) {
    97  			case error:
    98  				err = e
    99  			default:
   100  				err = errors.Errorf("%s", e)
   101  			}
   102  			err = errors.Wrap(err, "failed to encrypt")
   103  		}
   104  	}()
   105  
   106  	var bs keygen.ByteSource
   107  	if c.NonceGenerator == nil {
   108  		bs, err = keygen.NewRandom(aead.NonceSize()).Generate()
   109  	} else {
   110  		bs, err = c.NonceGenerator.Generate()
   111  	}
   112  	if err != nil {
   113  		return nil, nil, nil, errors.Wrap(err, "failed to generate nonce")
   114  	}
   115  	iv = bs.Bytes()
   116  
   117  	combined := aead.Seal(nil, iv, plaintext, aad)
   118  	tagoffset := len(combined) - c.TagSize()
   119  
   120  	if tagoffset < 0 {
   121  		panic(fmt.Sprintf("tag offset is less than 0 (combined len = %d, tagsize = %d)", len(combined), c.TagSize()))
   122  	}
   123  
   124  	tag = combined[tagoffset:]
   125  	ciphertext = make([]byte, tagoffset)
   126  	copy(ciphertext, combined[:tagoffset])
   127  
   128  	return
   129  }
   130  
   131  func (c AesContentCipher) Decrypt(cek, iv, ciphertxt, tag, aad []byte) (plaintext []byte, err error) {
   132  	aead, err := c.fetch.Fetch(cek)
   133  	if err != nil {
   134  		return nil, errors.Wrap(err, "failed to fetch AEAD data")
   135  	}
   136  
   137  	// Open may panic (argh!), so protect ourselves from that
   138  	defer func() {
   139  		if e := recover(); e != nil {
   140  			switch e := e.(type) {
   141  			case error:
   142  				err = e
   143  			default:
   144  				err = errors.Errorf("%s", e)
   145  			}
   146  			err = errors.Wrap(err, "failed to decrypt")
   147  			return
   148  		}
   149  	}()
   150  
   151  	combined := make([]byte, len(ciphertxt)+len(tag))
   152  	copy(combined, ciphertxt)
   153  	copy(combined[len(ciphertxt):], tag)
   154  
   155  	buf, aeaderr := aead.Open(nil, iv, combined, aad)
   156  	if aeaderr != nil {
   157  		err = errors.Wrap(aeaderr, `aead.Open failed`)
   158  		return
   159  	}
   160  	plaintext = buf
   161  	return
   162  }
   163  

View as plain text