1 package spiffe
2
3 import (
4 "context"
5 "crypto/x509"
6 "encoding/pem"
7 "strings"
8 "testing"
9
10 intoto "github.com/in-toto/in-toto-golang/in_toto"
11 "github.com/in-toto/in-toto-golang/internal/test"
12 "github.com/spiffe/go-spiffe/v2/spiffeid"
13 "github.com/spiffe/go-spiffe/v2/svid/x509svid"
14 "github.com/stretchr/testify/assert"
15 "github.com/stretchr/testify/require"
16 )
17
18 var (
19 td = spiffeid.RequireTrustDomainFromString("example.org")
20 fooID = spiffeid.RequireFromPath(td, "/foo")
21 )
22
// assertX509SVID checks that sd carries the expected SVID material: the leaf
// certificate (certificates[0]), the given intermediates, a private key, and a
// leaf certificate whose SPIFFE ID equals spiffeID.
func assertX509SVID(tb testing.TB, sd SVIDDetails, spiffeID spiffeid.ID, certificates []*x509.Certificate, intermediates []*x509.Certificate) {
	tb.Helper()

	assert.NotEmpty(tb, spiffeID)
	// Guard the index below so a malformed fixture fails cleanly instead of panicking.
	require.NotEmpty(tb, certificates, "expected at least one certificate")

	assert.Equal(tb, certificates[0], sd.Certificate)
	assert.Equal(tb, intermediates, sd.Intermediates)
	assert.NotEmpty(tb, sd.PrivateKey)

	// Previously spiffeID was never compared against the SVID. Verify the identity
	// encoded in the returned leaf certificate actually matches the expected one.
	if sd.Certificate != nil {
		gotID, err := x509svid.IDFromCert(sd.Certificate)
		if assert.NoError(tb, err, "extracting SPIFFE ID from SVID certificate") {
			assert.Equal(tb, spiffeID, gotID)
		}
	}
}
29
// assertInTotoKey checks that key is the in-toto representation of svid:
// an "ecdsa" key with scheme "ecdsa-sha2-nistp256", the SVID's private key
// re-encoded as a trimmed "PRIVATE KEY" PEM block, the leaf certificate's
// public key as a trimmed PKIX "PUBLIC KEY" PEM block, and the leaf
// certificate itself as a "CERTIFICATE" PEM block.
func assertInTotoKey(tb testing.TB, key intoto.Key, svid *x509svid.SVID) {
	tb.Helper()
	require.NotNil(tb, svid)

	// KeyID is a string, so NotNil could never fail; NotEmpty actually checks it.
	assert.NotEmpty(tb, key.KeyID, "keyID is empty.")
	assert.Equal(tb, []string{"sha256", "sha512"}, key.KeyIDHashAlgorithms)
	assert.Equal(tb, "ecdsa", key.KeyType)
	assert.Equal(tb, "ecdsa-sha2-nistp256", key.Scheme)

	certPEM, keyPEM, err := svid.Marshal()
	require.NoError(tb, err, "marshaling SVID")

	// pem.Decode returns a nil block on malformed input; fail instead of panicking on .Bytes.
	keyBlock, _ := pem.Decode(keyPEM)
	require.NotNil(tb, keyBlock, "SVID private key is not valid PEM")
	certBlock, _ := pem.Decode(certPEM)
	require.NotNil(tb, certBlock, "SVID certificate is not valid PEM")

	wantPrivate := pem.EncodeToMemory(&pem.Block{Bytes: keyBlock.Bytes, Type: "PRIVATE KEY"})
	assert.Equal(tb, strings.TrimSpace(string(wantPrivate)), key.KeyVal.Private)

	// The first PEM block in certPEM is the leaf certificate.
	leaf, err := x509.ParseCertificate(certBlock.Bytes)
	require.NoError(tb, err, "parsing SVID leaf certificate")
	pubDER, err := x509.MarshalPKIXPublicKey(leaf.PublicKey)
	require.NoError(tb, err, "marshaling SVID public key")
	wantPublic := pem.EncodeToMemory(&pem.Block{Bytes: pubDER, Type: "PUBLIC KEY"})
	assert.Equal(tb, strings.TrimSpace(string(wantPublic)), key.KeyVal.Public)

	// The certificate PEM is compared untrimmed, as in the original assertion.
	wantCert := pem.EncodeToMemory(&pem.Block{Bytes: svid.Certificates[0].Raw, Type: "CERTIFICATE"})
	assert.Equal(tb, string(wantCert), key.KeyVal.Certificate)
}
45
46 func makeX509SVIDs(ca *test.CA, ids ...spiffeid.ID) []*x509svid.SVID {
47 svids := []*x509svid.SVID{}
48 for _, id := range ids {
49 svids = append(svids, ca.CreateX509SVID(id))
50 }
51 return svids
52 }
53
54 func getSVIDs(t *testing.T, badInput bool) *test.X509SVIDResponse {
55 ca := test.NewCA(t, td)
56 var svids []*x509svid.SVID
57 if badInput {
58 svids = makeX509SVIDsNoPrivateKey(ca, fooID)
59 } else {
60 svids = makeX509SVIDs(ca, fooID)
61 }
62
63 resp := &test.X509SVIDResponse{
64 Bundle: ca.X509Bundle(),
65 SVIDs: svids,
66 }
67 return resp
68 }
69
70 func makeX509SVIDsNoPrivateKey(ca *test.CA, ids ...spiffeid.ID) []*x509svid.SVID {
71 svids := []*x509svid.SVID{}
72 for _, id := range ids {
73 svids = append(svids, ca.CreateX509SVIDNoPrivateKey(id))
74 }
75 return svids
76 }
77
// TestNewClient checks that NewClient connects to a fake Workload API and
// returns a usable, non-nil client.
func TestNewClient(t *testing.T) {
	wl := test.NewWorkloadAPI(t)
	defer wl.Stop()

	spireClient, err := NewClient(context.Background(), wl.Addr())
	require.NoError(t, err)
	// Check for nil before deferring Close so a nil client fails the test
	// cleanly rather than panicking inside the deferred call.
	require.NotNil(t, spireClient, "Unexpected error getting client")
	defer spireClient.Close()
}
88
// TestGetSVIDNoPrivateKey serves an SVID without a private key and expects
// GetSVID to fail and return zero-valued SVIDDetails.
func TestGetSVIDNoPrivateKey(t *testing.T) {
	ctx := context.Background()

	workload := test.NewWorkloadAPI(t)
	defer workload.Stop()

	client, err := NewClient(ctx, workload.Addr())
	require.NoError(t, err)
	defer client.Close()

	workload.SetX509SVIDResponse(getSVIDs(t, true))

	details, err := GetSVID(ctx, client)
	assert.Equal(t, SVIDDetails{}, details)
	assert.Error(t, err)
}
103
// TestGetSVID serves a valid SVID and checks that GetSVID returns its leaf
// certificate, the remaining chain as intermediates, and a private key.
func TestGetSVID(t *testing.T) {
	ctx := context.Background()

	workload := test.NewWorkloadAPI(t)
	defer workload.Stop()

	client, err := NewClient(ctx, workload.Addr())
	require.NoError(t, err)
	defer client.Close()

	response := getSVIDs(t, false)
	workload.SetX509SVIDResponse(response)

	details, err := GetSVID(ctx, client)
	require.NoError(t, err)

	chain := response.SVIDs[0].Certificates
	assertX509SVID(t, details, fooID, chain, chain[1:])
}
118
// TestSVIDDetails_IntotoKey fetches a valid SVID and checks that InTotoKey
// converts it into the expected in-toto key.
func TestSVIDDetails_IntotoKey(t *testing.T) {
	wl := test.NewWorkloadAPI(t)
	defer wl.Stop()

	spireClient, err := NewClient(context.Background(), wl.Addr())
	require.NoError(t, err)
	defer spireClient.Close()

	resp := getSVIDs(t, false)
	wl.SetX509SVIDResponse(resp)

	svidDetail, err := GetSVID(context.Background(), spireClient)
	require.NoError(t, err)

	key, err := svidDetail.InTotoKey()
	// Stop here on error: inspecting a zero-valued key would only add a
	// cascade of misleading assertion failures.
	require.NoError(t, err, "Unexpected error!")
	assertInTotoKey(t, key, resp.SVIDs[0])
}
137
// TestSVIDDetails_BadIntotoKey strips the private key from a valid SVID and
// expects InTotoKey to fail and return a zero-valued intoto.Key.
func TestSVIDDetails_BadIntotoKey(t *testing.T) {
	ctx := context.Background()

	workload := test.NewWorkloadAPI(t)
	defer workload.Stop()

	client, err := NewClient(ctx, workload.Addr())
	require.NoError(t, err)
	defer client.Close()

	workload.SetX509SVIDResponse(getSVIDs(t, false))

	details, err := GetSVID(ctx, client)
	require.NoError(t, err)

	// Remove the key after fetching so only InTotoKey's handling is exercised.
	details.PrivateKey = nil

	key, err := details.InTotoKey()
	assert.Equal(t, intoto.Key{}, key)
	assert.Error(t, err)
}
160
View as plain text