...

Source file src/github.com/ProtonMail/go-crypto/openpgp/packet/symmetrically_encrypted_test.go

Documentation: github.com/ProtonMail/go-crypto/openpgp/packet

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package packet
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/rand"
    10  	"crypto/sha1"
    11  	"encoding/hex"
    12  	goerrors "errors"
    13  	"io"
    14  	"io/ioutil"
    15  	"testing"
    16  
    17  	"github.com/ProtonMail/go-crypto/openpgp/errors"
    18  )
    19  
    20  // TestReader wraps a []byte and returns reads of a specific length.
    21  type testReader struct {
    22  	data   []byte
    23  	stride int
    24  }
    25  
    26  func (t *testReader) Read(buf []byte) (n int, err error) {
    27  	n = t.stride
    28  	if n > len(t.data) {
    29  		n = len(t.data)
    30  	}
    31  	if n > len(buf) {
    32  		n = len(buf)
    33  	}
    34  
    35  	copy(buf[:n], t.data)
    36  	t.data = t.data[n:]
    37  
    38  	if len(t.data) == 0 {
    39  		err = io.EOF
    40  	}
    41  
    42  	return
    43  }
    44  
    45  const mdcPlaintextHex = "cb1362000000000048656c6c6f2c20776f726c6421d314c23d643f478a9a2098811fcb191e7b24b80966a1"
    46  
    47  func TestMDCReader(t *testing.T) {
    48  	mdcPlaintext, _ := hex.DecodeString(mdcPlaintextHex)
    49  	for stride := 1; stride < len(mdcPlaintext)/2; stride++ {
    50  		r := &testReader{data: mdcPlaintext, stride: stride}
    51  		mdcReader := &seMDCReader{in: r, h: sha1.New()}
    52  		body, err := ioutil.ReadAll(mdcReader)
    53  		if err != nil {
    54  			t.Errorf("stride: %d, error: %s", stride, err)
    55  			continue
    56  		}
    57  		if !bytes.Equal(body, mdcPlaintext[:len(mdcPlaintext)-22]) {
    58  			t.Errorf("stride: %d: bad contents %x", stride, body)
    59  			continue
    60  		}
    61  
    62  		err = mdcReader.Close()
    63  		if err != nil {
    64  			t.Errorf("stride: %d, error on Close: %s", stride, err)
    65  		}
    66  	}
    67  
    68  	mdcPlaintext[15] ^= 80
    69  
    70  	r := &testReader{data: mdcPlaintext, stride: 2}
    71  	mdcReader := &seMDCReader{in: r, h: sha1.New()}
    72  	_, err := ioutil.ReadAll(mdcReader)
    73  	if err != nil {
    74  		t.Errorf("corruption test, error: %s", err)
    75  		return
    76  	}
    77  	err = mdcReader.Close()
    78  	if err == nil {
    79  		t.Error("corruption: no error")
    80  	} else if !goerrors.Is(err, errors.ErrMDCHashMismatch) {
    81  		t.Errorf("corruption: expected SignatureError, got: %s", err)
    82  	}
    83  }
    84  
    85  func TestSerializeMdc(t *testing.T) {
    86  	buf := bytes.NewBuffer(nil)
    87  	c := CipherAES128
    88  	key := make([]byte, c.KeySize())
    89  
    90  	cipherSuite := CipherSuite{
    91  		Cipher: c,
    92  		Mode:   AEADModeOCB,
    93  	}
    94  
    95  	w, err := SerializeSymmetricallyEncrypted(buf, c, false, cipherSuite, key, nil)
    96  	if err != nil {
    97  		t.Errorf("error from SerializeSymmetricallyEncrypted: %s", err)
    98  		return
    99  	}
   100  
   101  	contents := []byte("hello world\n")
   102  
   103  	w.Write(contents)
   104  	w.Close()
   105  
   106  	p, err := Read(buf)
   107  	if err != nil {
   108  		t.Errorf("error from Read: %s", err)
   109  		return
   110  	}
   111  
   112  	se, ok := p.(*SymmetricallyEncrypted)
   113  	if !ok {
   114  		t.Errorf("didn't read a *SymmetricallyEncrypted")
   115  		return
   116  	}
   117  
   118  	r, err := se.Decrypt(c, key)
   119  	if err != nil {
   120  		t.Errorf("error from Decrypt: %s", err)
   121  		return
   122  	}
   123  
   124  	contentsCopy := bytes.NewBuffer(nil)
   125  	_, err = io.Copy(contentsCopy, r)
   126  	if err != nil {
   127  		t.Errorf("error from io.Copy: %s", err)
   128  		return
   129  	}
   130  	if !bytes.Equal(contentsCopy.Bytes(), contents) {
   131  		t.Errorf("contents not equal got: %x want: %x", contentsCopy.Bytes(), contents)
   132  	}
   133  }
   134  
   135  const aeadHexKey = "1936fc8568980274bb900d8319360c77"
   136  const aeadHexSeipd = "d26902070306fcb94490bcb98bbdc9d106c6090266940f72e89edc21b5596b1576b101ed0f9ffc6fc6d65bbfd24dcd0790966e6d1e85a30053784cb1d8b6a0699ef12155a7b2ad6258531b57651fd7777912fa95e35d9b40216f69a4c248db28ff4331f1632907399e6ff9"
   137  const aeadHexPlainText = "cb1362000000000048656c6c6f2c20776f726c6421d50e1ce2269a9eddef81032172b7ed7c"
   138  const aeadExpectedSalt = "fcb94490bcb98bbdc9d106c6090266940f72e89edc21b5596b1576b101ed0f9f"
   139  
   140  func TestAeadRfcVector(t *testing.T) {
   141  	key, err := hex.DecodeString(aeadHexKey)
   142  	if err != nil {
   143  		t.Errorf("error in decoding key: %s", err)
   144  	}
   145  
   146  	packet, err := hex.DecodeString(aeadHexSeipd)
   147  	if err != nil {
   148  		t.Errorf("error in decoding packet: %s", err)
   149  	}
   150  
   151  	plainText, err := hex.DecodeString(aeadHexPlainText)
   152  	if err != nil {
   153  		t.Errorf("error in decoding plaintext: %s", err)
   154  	}
   155  
   156  	expectedSalt, err := hex.DecodeString(aeadExpectedSalt)
   157  	if err != nil {
   158  		t.Errorf("error in decoding salt: %s", err)
   159  	}
   160  
   161  	buf := bytes.NewBuffer(packet)
   162  	p, err := Read(buf)
   163  	if err != nil {
   164  		t.Errorf("error from Read: %s", err)
   165  		return
   166  	}
   167  
   168  	se, ok := p.(*SymmetricallyEncrypted)
   169  	if !ok {
   170  		t.Errorf("didn't read a *SymmetricallyEncrypted")
   171  		return
   172  	}
   173  
   174  	if se.Version != symmetricallyEncryptedVersionAead {
   175  		t.Errorf("found wrong version, want: %d, got: %d", symmetricallyEncryptedVersionAead, se.Version)
   176  	}
   177  
   178  	if se.Cipher != CipherAES128 {
   179  		t.Errorf("found wrong cipher, want: %d, got: %d", CipherAES128, se.Cipher)
   180  	}
   181  
   182  	if se.Mode != AEADModeGCM {
   183  		t.Errorf("found wrong mode, want: %d, got: %d", AEADModeGCM, se.Mode)
   184  	}
   185  
   186  	if !bytes.Equal(se.Salt[:], expectedSalt) {
   187  		t.Errorf("found wrong salt, want: %x, got: %x", expectedSalt, se.Salt)
   188  	}
   189  
   190  	if se.ChunkSizeByte != 0x06 {
   191  		t.Errorf("found wrong chunk size byte, want: %d, got: %d", 0x06, se.ChunkSizeByte)
   192  	}
   193  
   194  	aeadReader, err := se.Decrypt(CipherFunction(0), key)
   195  	if err != nil {
   196  		t.Errorf("error from Decrypt: %s", err)
   197  		return
   198  	}
   199  
   200  	decrypted, err := ioutil.ReadAll(aeadReader)
   201  	if err != nil {
   202  		t.Errorf("error when reading: %s", err)
   203  		return
   204  	}
   205  
   206  	err = aeadReader.Close()
   207  	if err != nil {
   208  		t.Errorf("error when closing reader: %s", err)
   209  		return
   210  	}
   211  
   212  	if !bytes.Equal(decrypted, plainText) {
   213  		t.Errorf("contents not equal got: %x want: %x", decrypted, plainText)
   214  	}
   215  }
   216  
   217  func TestAeadEncryptDecrypt(t *testing.T) {
   218  	ciphers := map[string]CipherFunction{
   219  		"AES128": CipherAES128,
   220  		"AES192": CipherAES192,
   221  		"AES256": CipherAES256,
   222  	}
   223  
   224  	modes := map[string]AEADMode{
   225  		"EAX": AEADModeEAX,
   226  		"OCB": AEADModeOCB,
   227  		"GCM": AEADModeGCM,
   228  	}
   229  
   230  	for cipherName, cipher := range ciphers {
   231  		t.Run(cipherName, func(t *testing.T) {
   232  			for modeName, mode := range modes {
   233  				t.Run(modeName, func(t *testing.T) {
   234  					testSerializeAead(t, CipherSuite{Cipher: cipher, Mode: mode})
   235  				})
   236  			}
   237  		})
   238  	}
   239  }
   240  
   241  func testSerializeAead(t *testing.T, cipherSuite CipherSuite) {
   242  	buf := bytes.NewBuffer(nil)
   243  	key := make([]byte, cipherSuite.Cipher.KeySize())
   244  	_, _ = rand.Read(key)
   245  
   246  	w, err := SerializeSymmetricallyEncrypted(buf, CipherFunction(0), true, cipherSuite, key, &Config{AEADConfig: &AEADConfig{}})
   247  	if err != nil {
   248  		t.Errorf("error from SerializeSymmetricallyEncrypted: %s", err)
   249  		return
   250  	}
   251  
   252  	contents := []byte("hello world\n")
   253  
   254  	w.Write(contents)
   255  	w.Close()
   256  
   257  	p, err := Read(buf)
   258  	if err != nil {
   259  		t.Errorf("error from Read: %s", err)
   260  		return
   261  	}
   262  
   263  	se, ok := p.(*SymmetricallyEncrypted)
   264  	if !ok {
   265  		t.Errorf("didn't read a *SymmetricallyEncrypted")
   266  		return
   267  	}
   268  
   269  	if se.Version != symmetricallyEncryptedVersionAead {
   270  		t.Errorf("found wrong version, want: %d, got: %d", symmetricallyEncryptedVersionAead, se.Version)
   271  	}
   272  
   273  	if se.Cipher != cipherSuite.Cipher {
   274  		t.Errorf("found wrong cipher, want: %d, got: %d", cipherSuite.Cipher, se.Cipher)
   275  	}
   276  
   277  	if se.Mode != cipherSuite.Mode {
   278  		t.Errorf("found wrong mode, want: %d, got: %d", cipherSuite.Mode, se.Mode)
   279  	}
   280  
   281  	r, err := se.Decrypt(CipherFunction(0), key)
   282  	if err != nil {
   283  		t.Errorf("error from Decrypt: %s", err)
   284  		return
   285  	}
   286  
   287  	contentsCopy := bytes.NewBuffer(nil)
   288  	_, err = io.Copy(contentsCopy, r)
   289  	if err != nil {
   290  		t.Errorf("error from io.Copy: %s", err)
   291  		return
   292  	}
   293  	if !bytes.Equal(contentsCopy.Bytes(), contents) {
   294  		t.Errorf("contents not equal got: %x want: %x", contentsCopy.Bytes(), contents)
   295  	}
   296  }
   297  

View as plain text