1 package magpie
2
3 import (
4 "context"
5 "crypto/rand"
6 "crypto/rsa"
7 "crypto/sha256"
8 "fmt"
9 "io"
10 "net/http"
11 "net/http/httptest"
12 "os"
13 "strings"
14 "testing"
15
16 "edge-infra.dev/pkg/edge/edgeencrypt"
17 "edge-infra.dev/pkg/lib/fog"
18
19 "github.com/golang-jwt/jwt"
20 "github.com/google/uuid"
21 "gotest.tools/v3/assert"
22 "sigs.k8s.io/controller-runtime/pkg/client"
23 "sigs.k8s.io/controller-runtime/pkg/client/fake"
24 )
25
26 var (
27 data = []byte("my-data")
28 cfg = &Config{
29 Namespace: edgeencrypt.DecryptionName,
30 JWTSecret: edgeencrypt.DecryptionJWTSecret,
31 KmsKey: edgeencrypt.KmsKey{
32 Location: "us-central1",
33 ProjectID: "my-project",
34 },
35 }
36 server *Server
37 cl client.Client
38 )
39
40 func TestMain(m *testing.M) {
41 cl = fake.NewFakeClient()
42 server = NewDecryptionServer(cfg, cl, fog.New(), nil)
43 os.Exit(m.Run())
44 }
45
46 func TestHealth(t *testing.T) {
47 w := httptest.NewRecorder()
48 req, err := http.NewRequest("GET", "/health", nil)
49 assert.NilError(t, err)
50 server.router.ServeHTTP(w, req)
51
52 assert.Equal(t, 200, w.Code)
53 assert.Equal(t, "ok", w.Body.String())
54 }
55
56 func TestDecryptUnAuthorize(t *testing.T) {
57 w := httptest.NewRecorder()
58
59 req, err := http.NewRequest("POST", "/decrypt", strings.NewReader(string(data)))
60 assert.NilError(t, err)
61 server.router.ServeHTTP(w, req)
62
63
64 assert.Equal(t, 401, w.Code)
65 }
66
67 func TestEncryptionSuccess(t *testing.T) {
68 privateKey, pem := createPublicPrivateKey(t)
69
70
71 encryptedData, err := edgeencrypt.EncryptData(pem, data)
72 assert.NilError(t, err)
73
74 e := &edgeencrypt.EncryptedData{
75 BannerEdgeID: "my-banner",
76 Channel: "my-channel",
77 ChannelID: uuid.NewString(),
78 KeyVersion: "1",
79 Data: encryptedData,
80 }
81 encryptedData, err = e.ToEncryptionResponse()
82 assert.NilError(t, err)
83
84 w := httptest.NewRecorder()
85 req, err := http.NewRequest("POST", "/v1/decrypt/"+e.Channel, strings.NewReader(string(encryptedData)))
86 assert.NilError(t, err)
87
88 token, err := edgeencrypt.CreateToken(jwt.SigningMethodRS256, privateKey, edgeencrypt.DefaultDuration, e.ChannelID, e.Channel, edgeencrypt.Decryption)
89 assert.NilError(t, err)
90
91 req.Header.Set("Authorization", "Bearer "+token)
92
93 server.decrypt = func(_ context.Context, _, _, _ string, aesKey []byte) ([]byte, error) {
94 return rsa.DecryptOAEP(sha256.New(), rand.Reader, privateKey, aesKey, nil)
95 }
96 server.router.ServeHTTP(w, req)
97
98 assert.Equal(t, 200, w.Code)
99 body, err := io.ReadAll(w.Body)
100 assert.NilError(t, err)
101
102 assert.Equal(t, string(data), string(body))
103 }
104
105 func createPublicPrivateKey(t *testing.T) (*rsa.PrivateKey, string) {
106 privateKey, err := rsa.GenerateKey(rand.Reader, edgeencrypt.RSA2048)
107 assert.NilError(t, err)
108
109 publicKey := &privateKey.PublicKey
110
111 pem, err := edgeencrypt.ConvertRSAPublicKeyToPEM(publicKey)
112 assert.NilError(t, err)
113
114 ee := &edgeencrypt.PublicKey{Version: "1", PEM: pem}
115
116
117 err = ee.Save(context.Background(), cl, edgeencrypt.DecryptionName, edgeencrypt.DecryptionJWTSecret)
118 assert.NilError(t, err)
119
120
121 channelEncryptionSecret := fmt.Sprintf(edgeencrypt.EncryptionSecret, "my-channel")
122 err = ee.Save(context.Background(), cl, edgeencrypt.DecryptionName, channelEncryptionSecret)
123 assert.NilError(t, err)
124
125 return privateKey, pem
126 }
127
View as plain text