...

Source file src/edge-infra.dev/pkg/edge/datasync/magpie/server_test.go

Documentation: edge-infra.dev/pkg/edge/datasync/magpie

     1  package magpie
     2  
     3  import (
     4  	"context"
     5  	"crypto/rand"
     6  	"crypto/rsa"
     7  	"crypto/sha256"
     8  	"fmt"
     9  	"io"
    10  	"net/http"
    11  	"net/http/httptest"
    12  	"os"
    13  	"strings"
    14  	"testing"
    15  
    16  	"edge-infra.dev/pkg/edge/edgeencrypt"
    17  	"edge-infra.dev/pkg/lib/fog"
    18  
    19  	"github.com/golang-jwt/jwt"
    20  	"github.com/google/uuid"
    21  	"gotest.tools/v3/assert"
    22  	"sigs.k8s.io/controller-runtime/pkg/client"
    23  	"sigs.k8s.io/controller-runtime/pkg/client/fake"
    24  )
    25  
    26  var (
    27  	data = []byte("my-data")
    28  	cfg  = &Config{
    29  		Namespace: edgeencrypt.DecryptionName,
    30  		JWTSecret: edgeencrypt.DecryptionJWTSecret,
    31  		KmsKey: edgeencrypt.KmsKey{
    32  			Location:  "us-central1",
    33  			ProjectID: "my-project",
    34  		},
    35  	}
    36  	server *Server
    37  	cl     client.Client
    38  )
    39  
    40  func TestMain(m *testing.M) {
    41  	cl = fake.NewFakeClient()
    42  	server = NewDecryptionServer(cfg, cl, fog.New(), nil)
    43  	os.Exit(m.Run())
    44  }
    45  
    46  func TestHealth(t *testing.T) {
    47  	w := httptest.NewRecorder()
    48  	req, err := http.NewRequest("GET", "/health", nil)
    49  	assert.NilError(t, err)
    50  	server.router.ServeHTTP(w, req)
    51  
    52  	assert.Equal(t, 200, w.Code)
    53  	assert.Equal(t, "ok", w.Body.String())
    54  }
    55  
    56  func TestDecryptUnAuthorize(t *testing.T) {
    57  	w := httptest.NewRecorder()
    58  
    59  	req, err := http.NewRequest("POST", "/decrypt", strings.NewReader(string(data)))
    60  	assert.NilError(t, err)
    61  	server.router.ServeHTTP(w, req)
    62  
    63  	// no bearer token
    64  	assert.Equal(t, 401, w.Code)
    65  }
    66  
    67  func TestEncryptionSuccess(t *testing.T) {
    68  	privateKey, pem := createPublicPrivateKey(t)
    69  
    70  	// encrypt data to decrypt for testing
    71  	encryptedData, err := edgeencrypt.EncryptData(pem, data)
    72  	assert.NilError(t, err)
    73  
    74  	e := &edgeencrypt.EncryptedData{
    75  		BannerEdgeID: "my-banner",
    76  		Channel:      "my-channel",
    77  		ChannelID:    uuid.NewString(),
    78  		KeyVersion:   "1",
    79  		Data:         encryptedData,
    80  	}
    81  	encryptedData, err = e.ToEncryptionResponse()
    82  	assert.NilError(t, err)
    83  
    84  	w := httptest.NewRecorder()
    85  	req, err := http.NewRequest("POST", "/v1/decrypt/"+e.Channel, strings.NewReader(string(encryptedData)))
    86  	assert.NilError(t, err)
    87  
    88  	token, err := edgeencrypt.CreateToken(jwt.SigningMethodRS256, privateKey, edgeencrypt.DefaultDuration, e.ChannelID, e.Channel, edgeencrypt.Decryption)
    89  	assert.NilError(t, err)
    90  
    91  	req.Header.Set("Authorization", "Bearer "+token)
    92  
    93  	server.decrypt = func(_ context.Context, _, _, _ string, aesKey []byte) ([]byte, error) {
    94  		return rsa.DecryptOAEP(sha256.New(), rand.Reader, privateKey, aesKey, nil)
    95  	}
    96  	server.router.ServeHTTP(w, req)
    97  
    98  	assert.Equal(t, 200, w.Code)
    99  	body, err := io.ReadAll(w.Body)
   100  	assert.NilError(t, err)
   101  
   102  	assert.Equal(t, string(data), string(body))
   103  }
   104  
   105  func createPublicPrivateKey(t *testing.T) (*rsa.PrivateKey, string) {
   106  	privateKey, err := rsa.GenerateKey(rand.Reader, edgeencrypt.RSA2048)
   107  	assert.NilError(t, err)
   108  
   109  	publicKey := &privateKey.PublicKey
   110  
   111  	pem, err := edgeencrypt.ConvertRSAPublicKeyToPEM(publicKey)
   112  	assert.NilError(t, err)
   113  
   114  	ee := &edgeencrypt.PublicKey{Version: "1", PEM: pem}
   115  
   116  	// public key to validate bearer token
   117  	err = ee.Save(context.Background(), cl, edgeencrypt.DecryptionName, edgeencrypt.DecryptionJWTSecret)
   118  	assert.NilError(t, err)
   119  
   120  	// public key to encrypt data
   121  	channelEncryptionSecret := fmt.Sprintf(edgeencrypt.EncryptionSecret, "my-channel")
   122  	err = ee.Save(context.Background(), cl, edgeencrypt.DecryptionName, channelEncryptionSecret)
   123  	assert.NilError(t, err)
   124  
   125  	return privateKey, pem
   126  }
   127  

View as plain text