1
16
17 package josecipher
18
19 import (
20 "bytes"
21 "crypto/cipher"
22 "crypto/hmac"
23 "crypto/sha256"
24 "crypto/sha512"
25 "crypto/subtle"
26 "encoding/binary"
27 "errors"
28 "hash"
29 )
30
// nonceBytes is the IV length, in bytes, reported by NonceSize. CBC requires
// the IV to be exactly one cipher block, so this assumes a 16-byte block
// cipher such as AES.
const nonceBytes = 16
34
35
36 func NewCBCHMAC(key []byte, newBlockCipher func([]byte) (cipher.Block, error)) (cipher.AEAD, error) {
37 keySize := len(key) / 2
38 integrityKey := key[:keySize]
39 encryptionKey := key[keySize:]
40
41 blockCipher, err := newBlockCipher(encryptionKey)
42 if err != nil {
43 return nil, err
44 }
45
46 var hash func() hash.Hash
47 switch keySize {
48 case 16:
49 hash = sha256.New
50 case 24:
51 hash = sha512.New384
52 case 32:
53 hash = sha512.New
54 }
55
56 return &cbcAEAD{
57 hash: hash,
58 blockCipher: blockCipher,
59 authtagBytes: keySize,
60 integrityKey: integrityKey,
61 }, nil
62 }
63
64
// cbcAEAD implements cipher.AEAD on top of a block cipher in CBC mode, with
// a truncated HMAC tag for integrity. Instances are built by NewCBCHMAC.
type cbcAEAD struct {
	// hash constructs the hash function used inside HMAC.
	hash func() hash.Hash
	// authtagBytes is how many leading HMAC output bytes form the tag.
	authtagBytes int
	// integrityKey is the HMAC key.
	integrityKey []byte
	// blockCipher is the cipher run in CBC mode for encryption/decryption.
	blockCipher cipher.Block
}
71
72 func (ctx *cbcAEAD) NonceSize() int {
73 return nonceBytes
74 }
75
76 func (ctx *cbcAEAD) Overhead() int {
77
78
79 return ctx.blockCipher.BlockSize() + ctx.authtagBytes
80 }
81
82
83 func (ctx *cbcAEAD) Seal(dst, nonce, plaintext, data []byte) []byte {
84
85 ciphertext := make([]byte, uint64(len(plaintext))+uint64(ctx.Overhead()))[:len(plaintext)]
86 copy(ciphertext, plaintext)
87 ciphertext = padBuffer(ciphertext, ctx.blockCipher.BlockSize())
88
89 cbc := cipher.NewCBCEncrypter(ctx.blockCipher, nonce)
90
91 cbc.CryptBlocks(ciphertext, ciphertext)
92 authtag := ctx.computeAuthTag(data, nonce, ciphertext)
93
94 ret, out := resize(dst, uint64(len(dst))+uint64(len(ciphertext))+uint64(len(authtag)))
95 copy(out, ciphertext)
96 copy(out[len(ciphertext):], authtag)
97
98 return ret
99 }
100
101
102 func (ctx *cbcAEAD) Open(dst, nonce, ciphertext, data []byte) ([]byte, error) {
103 if len(ciphertext) < ctx.authtagBytes {
104 return nil, errors.New("square/go-jose: invalid ciphertext (too short)")
105 }
106
107 offset := len(ciphertext) - ctx.authtagBytes
108 expectedTag := ctx.computeAuthTag(data, nonce, ciphertext[:offset])
109 match := subtle.ConstantTimeCompare(expectedTag, ciphertext[offset:])
110 if match != 1 {
111 return nil, errors.New("square/go-jose: invalid ciphertext (auth tag mismatch)")
112 }
113
114 cbc := cipher.NewCBCDecrypter(ctx.blockCipher, nonce)
115
116
117 buffer := append([]byte{}, []byte(ciphertext[:offset])...)
118
119 if len(buffer)%ctx.blockCipher.BlockSize() > 0 {
120 return nil, errors.New("square/go-jose: invalid ciphertext (invalid length)")
121 }
122
123 cbc.CryptBlocks(buffer, buffer)
124
125
126 plaintext, err := unpadBuffer(buffer, ctx.blockCipher.BlockSize())
127 if err != nil {
128 return nil, err
129 }
130
131 ret, out := resize(dst, uint64(len(dst))+uint64(len(plaintext)))
132 copy(out, plaintext)
133
134 return ret, nil
135 }
136
137
138 func (ctx *cbcAEAD) computeAuthTag(aad, nonce, ciphertext []byte) []byte {
139 buffer := make([]byte, uint64(len(aad))+uint64(len(nonce))+uint64(len(ciphertext))+8)
140 n := 0
141 n += copy(buffer, aad)
142 n += copy(buffer[n:], nonce)
143 n += copy(buffer[n:], ciphertext)
144 binary.BigEndian.PutUint64(buffer[n:], uint64(len(aad))*8)
145
146
147 hmac := hmac.New(ctx.hash, ctx.integrityKey)
148 _, _ = hmac.Write(buffer)
149
150 return hmac.Sum(nil)[:ctx.authtagBytes]
151 }
152
153
154
155
// resize returns head, a slice of length n whose first len(in) bytes are the
// contents of in. It also returns tail, the remainder head[len(in):], ready
// for the caller to fill. in's backing array is reused when its capacity
// is at least n; otherwise a new array of exactly n bytes is allocated and
// in is copied into it. n must be at least len(in), or slicing tail panics.
func resize(in []byte, n uint64) (head, tail []byte) {
	if uint64(cap(in)) < n {
		grown := make([]byte, n)
		copy(grown, in)
		return grown, grown[len(in):]
	}
	head = in[:n]
	return head, head[len(in):]
}
167
168
169 func padBuffer(buffer []byte, blockSize int) []byte {
170 missing := blockSize - (len(buffer) % blockSize)
171 ret, out := resize(buffer, uint64(len(buffer))+uint64(missing))
172 padding := bytes.Repeat([]byte{byte(missing)}, missing)
173 copy(out, padding)
174 return ret
175 }
176
177
// unpadBuffer strips the PKCS#7-style padding that padBuffer adds. buffer must
// be a non-empty multiple of blockSize. Its final byte n must satisfy
// 1 <= n <= blockSize, and each of its last n bytes must equal n. The
// returned slice aliases buffer with those n bytes removed. Any other input,
// including a non-positive blockSize, yields an "invalid padding" error
// instead of a panic.
//
// Only call this on authenticated data (Open checks the tag first); otherwise
// the distinct failure could act as a padding oracle.
func unpadBuffer(buffer []byte, blockSize int) ([]byte, error) {
	// An empty buffer passes the multiple-of-blockSize test but has no final
	// byte to read, and blockSize <= 0 would make the modulus panic.
	if blockSize <= 0 || len(buffer) == 0 || len(buffer)%blockSize != 0 {
		return nil, errors.New("square/go-jose: invalid padding")
	}

	last := buffer[len(buffer)-1]
	count := int(last)

	if count == 0 || count > blockSize || count > len(buffer) {
		return nil, errors.New("square/go-jose: invalid padding")
	}

	padding := bytes.Repeat([]byte{last}, count)
	if !bytes.HasSuffix(buffer, padding) {
		return nil, errors.New("square/go-jose: invalid padding")
	}

	return buffer[:len(buffer)-count], nil
}
197
View as plain text