...

Source file src/github.com/lestrrat-go/jwx/jwe/internal/concatkdf/concatkdf.go

Documentation: github.com/lestrrat-go/jwx/jwe/internal/concatkdf

     1  package concatkdf
     2  
     3  import (
     4  	"crypto"
     5  	"encoding/binary"
     6  
     7  	"github.com/pkg/errors"
     8  )
     9  
    10  type KDF struct {
    11  	buf       []byte
    12  	otherinfo []byte
    13  	z         []byte
    14  	hash      crypto.Hash
    15  }
    16  
    17  func ndata(src []byte) []byte {
    18  	buf := make([]byte, 4+len(src))
    19  	binary.BigEndian.PutUint32(buf, uint32(len(src)))
    20  	copy(buf[4:], src)
    21  	return buf
    22  }
    23  
    24  func New(hash crypto.Hash, alg, Z, apu, apv, pubinfo, privinfo []byte) *KDF {
    25  	algbuf := ndata(alg)
    26  	apubuf := ndata(apu)
    27  	apvbuf := ndata(apv)
    28  
    29  	concat := make([]byte, len(algbuf)+len(apubuf)+len(apvbuf)+len(pubinfo)+len(privinfo))
    30  	n := copy(concat, algbuf)
    31  	n += copy(concat[n:], apubuf)
    32  	n += copy(concat[n:], apvbuf)
    33  	n += copy(concat[n:], pubinfo)
    34  	copy(concat[n:], privinfo)
    35  
    36  	return &KDF{
    37  		hash:      hash,
    38  		otherinfo: concat,
    39  		z:         Z,
    40  	}
    41  }
    42  
    43  func (k *KDF) Read(out []byte) (int, error) {
    44  	var round uint32 = 1
    45  	h := k.hash.New()
    46  
    47  	for len(out) > len(k.buf) {
    48  		h.Reset()
    49  
    50  		if err := binary.Write(h, binary.BigEndian, round); err != nil {
    51  			return 0, errors.Wrap(err, "failed to write round using kdf")
    52  		}
    53  		if _, err := h.Write(k.z); err != nil {
    54  			return 0, errors.Wrap(err, "failed to write z using kdf")
    55  		}
    56  		if _, err := h.Write(k.otherinfo); err != nil {
    57  			return 0, errors.Wrap(err, "failed to write other info using kdf")
    58  		}
    59  
    60  		k.buf = append(k.buf, h.Sum(nil)...)
    61  		round++
    62  	}
    63  
    64  	n := copy(out, k.buf[:len(out)])
    65  	k.buf = k.buf[len(out):]
    66  	return n, nil
    67  }
    68  

View as plain text