...
1 package test
2
3 import (
4 "context"
5 "crypto/x509"
6 "errors"
7 "fmt"
8 "net"
9 "sync"
10 "testing"
11
12 "github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
13 "github.com/spiffe/go-spiffe/v2/proto/spiffe/workload"
14 "github.com/spiffe/go-spiffe/v2/svid/x509svid"
15 "github.com/stretchr/testify/require"
16 "google.golang.org/grpc"
17 "google.golang.org/grpc/codes"
18 "google.golang.org/grpc/metadata"
19 "google.golang.org/grpc/status"
20 )
21
22 var noIdentityError = status.Error(codes.PermissionDenied, "no identity issued")
23
24 type WorkloadAPI struct {
25 tb testing.TB
26 wg sync.WaitGroup
27 addr string
28 server *grpc.Server
29 mu sync.Mutex
30 x509Resp *workload.X509SVIDResponse
31 x509Chans map[chan *workload.X509SVIDResponse]struct{}
32 }
33
34 func NewWorkloadAPI(tb testing.TB) *WorkloadAPI {
35 w := &WorkloadAPI{
36 x509Chans: make(map[chan *workload.X509SVIDResponse]struct{}),
37 }
38
39 listener, err := net.Listen("tcp", "localhost:0")
40 require.NoError(tb, err)
41
42 server := grpc.NewServer()
43 workload.RegisterSpiffeWorkloadAPIServer(server, &workloadAPIWrapper{w: w})
44
45 w.wg.Add(1)
46 go func() {
47 defer w.wg.Done()
48 _ = server.Serve(listener)
49 }()
50
51 w.addr = fmt.Sprintf("%s://%s", listener.Addr().Network(), listener.Addr().String())
52 tb.Logf("WorkloadAPI address: %s", w.addr)
53 w.server = server
54 return w
55 }
56
57 func (w *WorkloadAPI) Stop() {
58 w.server.Stop()
59 w.wg.Wait()
60 }
61
62 func (w *WorkloadAPI) Addr() string {
63 return w.addr
64 }
65
66 func (w *WorkloadAPI) SetX509SVIDResponse(r *X509SVIDResponse) {
67 var resp *workload.X509SVIDResponse
68 if r != nil {
69 resp = r.ToProto(w.tb)
70 }
71
72 w.mu.Lock()
73 defer w.mu.Unlock()
74 w.x509Resp = resp
75
76 for ch := range w.x509Chans {
77 select {
78 case ch <- resp:
79 default:
80 <-ch
81 ch <- resp
82 }
83 }
84 }
85
// concatRawCertsFromCerts joins the raw DER bytes of each certificate, in
// order, into a single slice. It returns nil when there is nothing to join.
func concatRawCertsFromCerts(certs []*x509.Certificate) []byte {
	var out []byte
	for i := range certs {
		out = append(out, certs[i].Raw...)
	}
	return out
}
93
94 func (r *X509SVIDResponse) ToProto(tb testing.TB) *workload.X509SVIDResponse {
95 var bundle []byte
96 if r.Bundle != nil {
97 bundle = concatRawCertsFromCerts(r.Bundle.X509Authorities())
98 }
99
100 pb := &workload.X509SVIDResponse{
101 FederatedBundles: make(map[string][]byte),
102 }
103 for _, svid := range r.SVIDs {
104 var keyDER []byte
105 if svid.PrivateKey != nil {
106 var err error
107 keyDER, err = x509.MarshalPKCS8PrivateKey(svid.PrivateKey)
108 require.NoError(tb, err)
109 }
110 pb.Svids = append(pb.Svids, &workload.X509SVID{
111 SpiffeId: svid.ID.String(),
112 X509Svid: concatRawCertsFromCerts(svid.Certificates),
113 X509SvidKey: keyDER,
114 Bundle: bundle,
115 })
116 }
117 for _, v := range r.FederatedBundles {
118 pb.FederatedBundles[v.TrustDomain().IDString()] = concatRawCertsFromCerts(v.X509Authorities())
119 }
120
121 return pb
122 }
123
124 type workloadAPIWrapper struct {
125 workload.UnimplementedSpiffeWorkloadAPIServer
126 w *WorkloadAPI
127 }
128
129 func (w *workloadAPIWrapper) FetchX509SVID(req *workload.X509SVIDRequest, stream workload.SpiffeWorkloadAPI_FetchX509SVIDServer) error {
130 return w.w.fetchX509SVID(req, stream)
131 }
132
133 type X509SVIDResponse struct {
134 SVIDs []*x509svid.SVID
135 Bundle *x509bundle.Bundle
136 FederatedBundles []*x509bundle.Bundle
137 }
138
139 func (w *WorkloadAPI) fetchX509SVID(_ *workload.X509SVIDRequest, stream workload.SpiffeWorkloadAPI_FetchX509SVIDServer) error {
140 if err := checkHeader(stream.Context()); err != nil {
141 return err
142 }
143 ch := make(chan *workload.X509SVIDResponse, 1)
144 w.mu.Lock()
145 w.x509Chans[ch] = struct{}{}
146 resp := w.x509Resp
147 w.mu.Unlock()
148
149 defer func() {
150 w.mu.Lock()
151 delete(w.x509Chans, ch)
152 w.mu.Unlock()
153 }()
154
155 sendResp := func(resp *workload.X509SVIDResponse) error {
156 if resp == nil {
157 return noIdentityError
158 }
159 return stream.Send(resp)
160 }
161
162 if err := sendResp(resp); err != nil {
163 return err
164 }
165 for {
166 select {
167 case resp := <-ch:
168 if err := sendResp(resp); err != nil {
169 return err
170 }
171 case <-stream.Context().Done():
172 return stream.Context().Err()
173 }
174 }
175 }
176
177 func checkHeader(ctx context.Context) error {
178 return checkMetadata(ctx, "workload.spiffe.io", "true")
179 }
180
181 func checkMetadata(ctx context.Context, key, value string) error {
182 md, ok := metadata.FromIncomingContext(ctx)
183 if !ok {
184 return errors.New("request does not contain metadata")
185 }
186 values := md.Get(key)
187 if len(value) == 0 {
188 return fmt.Errorf("request metadata does not contain %q value", key)
189 }
190 if values[0] != value {
191 return fmt.Errorf("request metadata %q value is %q; expected %q", key, values[0], value)
192 }
193 return nil
194 }
195
View as plain text