1 package sparrow
2
3 import (
4 "context"
5 "crypto/rand"
6 "crypto/rsa"
7 "crypto/sha256"
8 "fmt"
9 "io"
10 "net/http"
11 "net/http/httptest"
12 "os"
13 "strings"
14 "testing"
15
16 "github.com/golang-jwt/jwt"
17 "github.com/google/uuid"
18 "gotest.tools/v3/assert"
19 "sigs.k8s.io/controller-runtime/pkg/client"
20 "sigs.k8s.io/controller-runtime/pkg/client/fake"
21
22 "edge-infra.dev/pkg/edge/edgeencrypt"
23 "edge-infra.dev/pkg/lib/fog"
24 )
25
26 var (
27 data = []byte("my-data")
28 cfg = &Config{Namespace: edgeencrypt.EncryptionName, JWTSecret: edgeencrypt.EncryptionJWTSecret}
29 server *Server
30 cl client.Client
31 )
32
33 func TestMain(m *testing.M) {
34 cl = fake.NewFakeClient()
35 server = NewEncryptionServer(cfg, cl, fog.New())
36 os.Exit(m.Run())
37 }
38
39 func TestHealth(t *testing.T) {
40 w := httptest.NewRecorder()
41 req, err := http.NewRequest("GET", "/health", nil)
42 assert.NilError(t, err)
43 server.router.ServeHTTP(w, req)
44
45 assert.Equal(t, 200, w.Code)
46 assert.Equal(t, "ok", w.Body.String())
47 }
48
49 func TestEncryptUnAuthorize(t *testing.T) {
50 w := httptest.NewRecorder()
51
52 req, err := http.NewRequest("POST", "/v1/encrypt", strings.NewReader(string(data)))
53 assert.NilError(t, err)
54 server.router.ServeHTTP(w, req)
55
56
57 assert.Equal(t, 401, w.Code)
58 }
59
60 func TestEncryptionSuccess(t *testing.T) {
61 w := httptest.NewRecorder()
62 req, err := http.NewRequest("POST", "/v1/encrypt", strings.NewReader(string(data)))
63 assert.NilError(t, err)
64
65 privateKey := createPublicPrivateKey(t)
66
67 channelID := uuid.NewString()
68 channelName := "my-channel"
69 token, err := edgeencrypt.CreateToken(jwt.SigningMethodRS256, privateKey, edgeencrypt.DefaultDuration, channelID, channelName, edgeencrypt.Encryption, "my-banner")
70 assert.NilError(t, err)
71
72 req.Header.Set("Authorization", "Bearer "+token)
73
74 server.router.ServeHTTP(w, req)
75
76 assert.Equal(t, 200, w.Code)
77 body, err := io.ReadAll(w.Body)
78 assert.NilError(t, err)
79
80 e := &edgeencrypt.EncryptedData{}
81 err = e.FromDecryptionRequest(body)
82 assert.NilError(t, err)
83 assert.Equal(t, channelName, e.Channel)
84 assert.Equal(t, "1", e.KeyVersion)
85
86 ec := &edgeencrypt.EncryptionClaims{
87 ChannelID: e.ChannelID,
88 Channel: channelName,
89 Role: edgeencrypt.Decryption,
90 }
91 assert.NilError(t, e.Valid())
92
93 decryptedData, err := edgeencrypt.DecryptData(context.Background(), e, ec,
94 func(_ context.Context, _, _, _ string, aesKey []byte) ([]byte, error) {
95 return rsa.DecryptOAEP(sha256.New(), rand.Reader, privateKey, aesKey, nil)
96 })
97 assert.NilError(t, err)
98
99 assert.Equal(t, string(data), string(decryptedData))
100 }
101
102 func createPublicPrivateKey(t *testing.T) *rsa.PrivateKey {
103 privateKey, err := rsa.GenerateKey(rand.Reader, edgeencrypt.RSA2048)
104 assert.NilError(t, err)
105
106 publicKey := &privateKey.PublicKey
107
108 pem, err := edgeencrypt.ConvertRSAPublicKeyToPEM(publicKey)
109 assert.NilError(t, err)
110
111 ee := &edgeencrypt.PublicKey{Version: "1", PEM: pem}
112
113
114 err = ee.Save(context.Background(), cl, edgeencrypt.EncryptionName, edgeencrypt.EncryptionJWTSecret)
115 assert.NilError(t, err)
116
117
118 channelEncryptionSecret := fmt.Sprintf(edgeencrypt.EncryptionSecret, "my-channel")
119 err = ee.Save(context.Background(), cl, edgeencrypt.EncryptionName, channelEncryptionSecret)
120 assert.NilError(t, err)
121
122 return privateKey
123 }
124
View as plain text