...

Source file src/github.com/ProtonMail/go-crypto/openpgp/ecdh/ecdh_test.go

Documentation: github.com/ProtonMail/go-crypto/openpgp/ecdh

     1  // Copyright 2017 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package ecdh implements ECDH encryption, suitable for OpenPGP,
     6  // as specified in RFC 6637, section 8.
     7  package ecdh
     8  
     9  import (
    10  	"bytes"
    11  	"crypto/rand"
    12  	"github.com/ProtonMail/go-crypto/openpgp/internal/ecc"
    13  	"io"
    14  	"testing"
    15  
    16  	"github.com/ProtonMail/go-crypto/openpgp/internal/algorithm"
    17  )
    18  
    19  func TestCurves(t *testing.T) {
    20  	for _, curve := range ecc.Curves {
    21  		ECDHCurve, ok := curve.Curve.(ecc.ECDHCurve)
    22  		if !ok {
    23  			continue
    24  		}
    25  
    26  		t.Run(ECDHCurve.GetCurveName(), func(t *testing.T) {
    27  			testFingerprint := make([]byte, 20)
    28  			_, err := io.ReadFull(rand.Reader, testFingerprint[:])
    29  			if err != nil {
    30  				t.Fatal(err)
    31  			}
    32  
    33  			priv := testGenerate(t, ECDHCurve)
    34  			testEncryptDecrypt(t, priv, curve.Oid.Bytes(), testFingerprint)
    35  			testValidation(t, priv)
    36  
    37  			// Needs fresh key
    38  			priv = testGenerate(t, ECDHCurve)
    39  			testMarshalUnmarshal(t, priv)
    40  		})
    41  	}
    42  }
    43  
    44  func testGenerate(t *testing.T, curve ecc.ECDHCurve) *PrivateKey {
    45  	kdf := KDF{
    46  		Hash:   algorithm.SHA512,
    47  		Cipher: algorithm.AES256,
    48  	}
    49  
    50  	priv, err := GenerateKey(rand.Reader, curve, kdf)
    51  	if err != nil {
    52  		t.Fatal(err)
    53  	}
    54  
    55  	return priv
    56  }
    57  
    58  func testEncryptDecrypt(t *testing.T, priv *PrivateKey, oid, fingerprint []byte) {
    59  	message := []byte("hello world")
    60  
    61  	vsG, m, err := Encrypt(rand.Reader, &priv.PublicKey, message, oid, fingerprint)
    62  	if err != nil {
    63  		t.Errorf("error encrypting: %s", err)
    64  	}
    65  
    66  	message2, err := Decrypt(priv, vsG, m, oid, fingerprint)
    67  	if err != nil {
    68  		t.Errorf("error decrypting: %s", err)
    69  	}
    70  
    71  	if !bytes.Equal(message2, message) {
    72  		t.Errorf("decryption failed, got: %x, want: %x", message2, message)
    73  	}
    74  }
    75  
    76  func testValidation(t *testing.T, priv *PrivateKey) {
    77  	if err := Validate(priv); err != nil {
    78  		t.Fatalf("valid key marked as invalid: %s", err)
    79  	}
    80  
    81  	priv.D[5] ^= 1
    82  	if err := Validate(priv); err == nil {
    83  		t.Fatalf("failed to detect invalid key")
    84  	}
    85  }
    86  
    87  func testMarshalUnmarshal(t *testing.T, priv *PrivateKey) {
    88  	p := priv.MarshalPoint()
    89  	d := priv.MarshalByteSecret()
    90  
    91  	parsed := NewPrivateKey(*NewPublicKey(priv.GetCurve(), priv.KDF.Hash, priv.KDF.Cipher))
    92  
    93  	if err := parsed.UnmarshalPoint(p); err != nil {
    94  		t.Fatalf("unable to unmarshal point: %s", err)
    95  	}
    96  
    97  	if err := parsed.UnmarshalByteSecret(d); err != nil {
    98  		t.Fatalf("unable to unmarshal integer: %s", err)
    99  	}
   100  
   101  	expectedD := make([]byte, len(priv.D))
   102  	copy(expectedD, priv.D)
   103  
   104  	// Curve25519 expects keys to be saved clamped
   105  	if priv.curve.GetCurveName() == "curve25519" {
   106  		expectedD[0] &= 248
   107  		expectedD[31] &= 127
   108  		expectedD[31] |= 64
   109  	}
   110  
   111  	if !bytes.Equal(priv.Point, parsed.Point) || !bytes.Equal(expectedD, parsed.D) {
   112  		t.Fatal("failed to marshal/unmarshal correctly")
   113  	}
   114  }
   115  

View as plain text